diff --git a/codex-rs/core/src/apply_patch.rs b/codex-rs/core/src/apply_patch.rs index fcccb40f..4f9292b6 100644 --- a/codex-rs/core/src/apply_patch.rs +++ b/codex-rs/core/src/apply_patch.rs @@ -1,4 +1,5 @@ use crate::codex::Session; +use crate::codex::TurnContext; use crate::models::FunctionCallOutputPayload; use crate::models::ResponseInputItem; use crate::protocol::FileChange; @@ -40,15 +41,16 @@ impl From for InternalApplyPatchInvocation { pub(crate) async fn apply_patch( sess: &Session, + turn_context: &TurnContext, sub_id: &str, call_id: &str, action: ApplyPatchAction, ) -> InternalApplyPatchInvocation { match assess_patch_safety( &action, - sess.get_approval_policy(), - sess.get_sandbox_policy(), - sess.get_cwd(), + turn_context.approval_policy, + &turn_context.sandbox_policy, + &turn_context.cwd, ) { SafetyCheck::AutoApprove { .. } => { InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec { diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 3d26bd08..686ec79d 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -56,7 +56,7 @@ struct Error { message: Option, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ModelClient { config: Arc, auth: Option, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index f66abe00..b3f87223 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; use std::collections::HashSet; -use std::path::Path; use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; @@ -166,16 +165,17 @@ impl Codex { }; // Generate a unique ID for the lifetime of this Codex session. - let session = Session::new(configure_session, config.clone(), auth, tx_event.clone()) - .await - .map_err(|e| { - error!("Failed to create session: {e:#}"); - CodexErr::InternalAgentDied - })?; + let (session, turn_context) = + Session::new(configure_session, config.clone(), auth, tx_event.clone()) + .await + .map_err(|e| { + error!("Failed to create session: {e:#}"); + CodexErr::InternalAgentDied + })?; let session_id = session.session_id; // This task will run until Op::Shutdown is received. - tokio::spawn(submission_loop(session, config, rx_sub)); + tokio::spawn(submission_loop(session, turn_context, config, rx_sub)); let codex = Codex { next_id: AtomicU64::new(0), tx_sub, @@ -231,21 +231,8 @@ struct State { /// A session has at most 1 running task at a time, and can be interrupted by user input. pub(crate) struct Session { session_id: Uuid, - client: ModelClient, tx_event: Sender, - /// The session's current working directory. All relative paths provided by - /// the model as well as sandbox policies are resolved against this path - /// instead of `std::env::current_dir()`. - cwd: PathBuf, - base_instructions: Option, - user_instructions: Option, - approval_policy: AskForApproval, - sandbox_policy: SandboxPolicy, - shell_environment_policy: ShellEnvironmentPolicy, - disable_response_storage: bool, - tools_config: ToolsConfig, - /// Manager for external MCP servers/tools. mcp_connection_manager: McpConnectionManager, @@ -262,6 +249,31 @@ pub(crate) struct Session { show_raw_agent_reasoning: bool, } +/// The context needed for a single turn of the conversation. +#[derive(Debug)] +pub(crate) struct TurnContext { + pub(crate) client: ModelClient, + /// The session's current working directory. All relative paths provided by + /// the model as well as sandbox policies are resolved against this path + /// instead of `std::env::current_dir()`. + pub(crate) cwd: PathBuf, + pub(crate) base_instructions: Option, + pub(crate) user_instructions: Option, + pub(crate) approval_policy: AskForApproval, + pub(crate) sandbox_policy: SandboxPolicy, + pub(crate) shell_environment_policy: ShellEnvironmentPolicy, + pub(crate) disable_response_storage: bool, + pub(crate) tools_config: ToolsConfig, +} + +impl TurnContext { + fn resolve_path(&self, path: Option) -> PathBuf { + path.as_ref() + .map(PathBuf::from) + .map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p)) + } +} + /// Configure the model session. struct ConfigureSession { /// Provider identifier ("openai", "openrouter", ...). @@ -309,7 +321,7 @@ impl Session { config: Arc, auth: Option, tx_event: Sender, - ) -> anyhow::Result> { + ) -> anyhow::Result<(Arc, TurnContext)> { let ConfigureSession { provider, model, @@ -457,8 +469,7 @@ impl Session { model_reasoning_summary, session_id, ); - let sess = Arc::new(Session { - session_id, + let turn_context = TurnContext { client, tools_config: ToolsConfig::new( &config.model_family, @@ -467,19 +478,22 @@ impl Session { config.include_plan_tool, config.include_apply_patch_tool, ), - tx_event: tx_event.clone(), user_instructions, base_instructions, approval_policy, sandbox_policy, shell_environment_policy: config.shell_environment_policy.clone(), cwd, + disable_response_storage, + }; + let sess = Arc::new(Session { + session_id, + tx_event: tx_event.clone(), mcp_connection_manager, notify, state: Mutex::new(state), rollout: Mutex::new(rollout_recorder), codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), - disable_response_storage, user_shell: default_shell, show_raw_agent_reasoning: config.show_raw_agent_reasoning, }); @@ -487,13 +501,13 @@ impl Session { // record the initial user instructions and environment context, // regardless of whether we restored items. let mut conversation_items = Vec::::with_capacity(2); - if let Some(user_instructions) = sess.user_instructions.as_deref() { + if let Some(user_instructions) = turn_context.user_instructions.as_deref() { conversation_items.push(Prompt::format_user_instructions_message(user_instructions)); } conversation_items.push(ResponseItem::from(EnvironmentContext::new( - sess.get_cwd().to_path_buf(), - sess.get_approval_policy(), - sess.sandbox_policy.clone(), + turn_context.cwd.to_path_buf(), + turn_context.approval_policy, + turn_context.sandbox_policy.clone(), ))); sess.record_conversation_items(&conversation_items).await; @@ -514,25 +528,7 @@ impl Session { } } - Ok(sess) - } - - pub(crate) fn get_approval_policy(&self) -> AskForApproval { - self.approval_policy - } - - pub(crate) fn get_sandbox_policy(&self) -> &SandboxPolicy { - &self.sandbox_policy - } - - pub(crate) fn get_cwd(&self) -> &Path { - &self.cwd - } - - fn resolve_path(&self, path: Option) -> PathBuf { - path.as_ref() - .map(PathBuf::from) - .map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p)) + Ok((sess, turn_context)) } pub fn set_task(&self, task: AgentTask) { @@ -921,9 +917,19 @@ pub(crate) struct AgentTask { } impl AgentTask { - fn spawn(sess: Arc, sub_id: String, input: Vec) -> Self { - let handle = - tokio::spawn(run_task(Arc::clone(&sess), sub_id.clone(), input)).abort_handle(); + fn spawn( + sess: Arc, + turn_context: Arc, + sub_id: String, + input: Vec, + ) -> Self { + let handle = { + let sess = sess.clone(); + let sub_id = sub_id.clone(); + let tc = Arc::clone(&turn_context); + tokio::spawn(async move { run_task(sess, tc.as_ref(), sub_id, input).await }) + .abort_handle() + }; Self { sess, sub_id, @@ -933,17 +939,20 @@ impl AgentTask { fn compact( sess: Arc, + turn_context: Arc, sub_id: String, input: Vec, compact_instructions: String, ) -> Self { - let handle = tokio::spawn(run_compact_task( - Arc::clone(&sess), - sub_id.clone(), - input, - compact_instructions, - )) - .abort_handle(); + let handle = { + let sess = sess.clone(); + let sub_id = sub_id.clone(); + let tc = Arc::clone(&turn_context); + tokio::spawn(async move { + run_compact_task(sess, tc.as_ref(), sub_id, input, compact_instructions).await + }) + .abort_handle() + }; Self { sess, sub_id, @@ -968,7 +977,14 @@ impl AgentTask { } } -async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiver) { +async fn submission_loop( + sess: Arc, + turn_context: TurnContext, + config: Arc, + rx_sub: Receiver, +) { + // Wrap once to avoid cloning TurnContext for each task. + let turn_context = Arc::new(turn_context); // To break out of this loop, send Op::Shutdown. while let Ok(sub) = rx_sub.recv().await { debug!(?sub, "Submission"); @@ -980,7 +996,8 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv // attempt to inject input into current task if let Err(items) = sess.inject_input(items) { // no current task, spawn a new one - let task = AgentTask::spawn(sess.clone(), sub.id, items); + let task = + AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); sess.set_task(task); } } @@ -1046,6 +1063,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv }]) { let task = AgentTask::compact( sess.clone(), + Arc::clone(&turn_context), sub.id, items, SUMMARIZATION_PROMPT.to_string(), @@ -1101,7 +1119,12 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv /// back to the model in the next turn. /// - If the model sends only an assistant message, we record it in the /// conversation history and consider the task complete. -async fn run_task(sess: Arc, sub_id: String, input: Vec) { +async fn run_task( + sess: Arc, + turn_context: &TurnContext, + sub_id: String, + input: Vec, +) { if input.is_empty() { return; } @@ -1153,7 +1176,15 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { }) }) .collect(); - match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await { + match run_turn( + &sess, + turn_context, + &mut turn_diff_tracker, + sub_id.clone(), + turn_input, + ) + .await + { Ok(turn_output) => { let mut items_to_record_in_conversation_history = Vec::::new(); let mut responses = Vec::::new(); @@ -1282,25 +1313,26 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { async fn run_turn( sess: &Session, + turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, sub_id: String, input: Vec, ) -> CodexResult> { let tools = get_openai_tools( - &sess.tools_config, + &turn_context.tools_config, Some(sess.mcp_connection_manager.list_all_tools()), ); let prompt = Prompt { input, - store: !sess.disable_response_storage, + store: !turn_context.disable_response_storage, tools, - base_instructions_override: sess.base_instructions.clone(), + base_instructions_override: turn_context.base_instructions.clone(), }; let mut retries = 0; loop { - match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await { + match try_run_turn(sess, turn_context, turn_diff_tracker, &sub_id, &prompt).await { Ok(output) => return Ok(output), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), @@ -1309,7 +1341,7 @@ async fn run_turn( } Err(e) => { // Use the configured provider-specific stream retry budget. - let max_retries = sess.client.get_provider().stream_max_retries(); + let max_retries = turn_context.client.get_provider().stream_max_retries(); if retries < max_retries { retries += 1; let delay = match e { @@ -1352,6 +1384,7 @@ struct ProcessedResponseItem { async fn try_run_turn( sess: &Session, + turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, sub_id: &str, prompt: &Prompt, @@ -1412,7 +1445,7 @@ async fn try_run_turn( }) }; - let mut stream = sess.client.clone().stream(&prompt).await?; + let mut stream = turn_context.client.clone().stream(&prompt).await?; let mut output = Vec::new(); loop { @@ -1441,9 +1474,14 @@ async fn try_run_turn( match event { ResponseEvent::Created => {} ResponseEvent::OutputItemDone(item) => { - let response = - handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?; - + let response = handle_response_item( + sess, + turn_context, + turn_diff_tracker, + sub_id, + item.clone(), + ) + .await?; output.push(ProcessedResponseItem { item, response }); } ResponseEvent::Completed { @@ -1515,6 +1553,7 @@ async fn try_run_turn( async fn run_compact_task( sess: Arc, + turn_context: &TurnContext, sub_id: String, input: Vec, compact_instructions: String, @@ -1533,16 +1572,16 @@ async fn run_compact_task( let prompt = Prompt { input: turn_input, - store: !sess.disable_response_storage, + store: !turn_context.disable_response_storage, tools: Vec::new(), base_instructions_override: Some(compact_instructions.clone()), }; - let max_retries = sess.client.get_provider().stream_max_retries(); + let max_retries = turn_context.client.get_provider().stream_max_retries(); let mut retries = 0; loop { - let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await; + let attempt_result = drain_to_completed(&sess, turn_context, &sub_id, &prompt).await; match attempt_result { Ok(()) => break, @@ -1596,6 +1635,7 @@ async fn run_compact_task( async fn handle_response_item( sess: &Session, + turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, sub_id: &str, item: ResponseItem, @@ -1659,6 +1699,7 @@ async fn handle_response_item( Some( handle_function_call( sess, + turn_context, turn_diff_tracker, sub_id.to_string(), name, @@ -1698,11 +1739,12 @@ async fn handle_response_item( } }; - let exec_params = to_exec_params(params, sess); + let exec_params = to_exec_params(params, turn_context); Some( handle_container_exec_with_params( exec_params, sess, + turn_context, turn_diff_tracker, sub_id.to_string(), effective_call_id, @@ -1721,6 +1763,7 @@ async fn handle_response_item( async fn handle_function_call( sess: &Session, + turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, sub_id: String, name: String, @@ -1729,14 +1772,21 @@ async fn handle_function_call( ) -> ResponseInputItem { match name.as_str() { "container.exec" | "shell" => { - let params = match parse_container_exec_arguments(arguments, sess, &call_id) { + let params = match parse_container_exec_arguments(arguments, turn_context, &call_id) { Ok(params) => params, Err(output) => { return *output; } }; - handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id) - .await + handle_container_exec_with_params( + params, + sess, + turn_context, + turn_diff_tracker, + sub_id, + call_id, + ) + .await } "apply_patch" => { let args = match serde_json::from_str::(&arguments) { @@ -1753,14 +1803,21 @@ async fn handle_function_call( }; let exec_params = ExecParams { command: vec!["apply_patch".to_string(), args.input.clone()], - cwd: sess.cwd.clone(), + cwd: turn_context.cwd.clone(), timeout_ms: None, env: HashMap::new(), with_escalated_permissions: None, justification: None, }; - handle_container_exec_with_params(exec_params, sess, turn_diff_tracker, sub_id, call_id) - .await + handle_container_exec_with_params( + exec_params, + sess, + turn_context, + turn_diff_tracker, + sub_id, + call_id, + ) + .await } "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await, _ => { @@ -1788,12 +1845,12 @@ async fn handle_function_call( } } -fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams { +fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams { ExecParams { command: params.command, - cwd: sess.resolve_path(params.workdir.clone()), + cwd: turn_context.resolve_path(params.workdir.clone()), timeout_ms: params.timeout_ms, - env: create_env(&sess.shell_environment_policy), + env: create_env(&turn_context.shell_environment_policy), with_escalated_permissions: params.with_escalated_permissions, justification: params.justification, } @@ -1801,12 +1858,12 @@ fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams { fn parse_container_exec_arguments( arguments: String, - sess: &Session, + turn_context: &TurnContext, call_id: &str, ) -> Result> { // parse command match serde_json::from_str::(&arguments) { - Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)), + Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, turn_context)), Err(e) => { // allow model to re-sample let output = ResponseInputItem::FunctionCallOutput { @@ -1829,8 +1886,12 @@ pub struct ExecInvokeArgs<'a> { pub stdout_stream: Option, } -fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams { - if sess.shell_environment_policy.use_profile { +fn maybe_run_with_user_profile( + params: ExecParams, + sess: &Session, + turn_context: &TurnContext, +) -> ExecParams { + if turn_context.shell_environment_policy.use_profile { let command = sess .user_shell .format_default_shell_invocation(params.command.clone()); @@ -1844,6 +1905,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams async fn handle_container_exec_with_params( params: ExecParams, sess: &Session, + turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, sub_id: String, call_id: String, @@ -1851,7 +1913,7 @@ async fn handle_container_exec_with_params( // check if this was a patch, and apply it if so let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { MaybeApplyPatchVerified::Body(changes) => { - match apply_patch::apply_patch(sess, &sub_id, &call_id, changes).await { + match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await { InternalApplyPatchInvocation::Output(item) => return item, InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { Some(apply_patch_exec) @@ -1913,8 +1975,8 @@ async fn handle_container_exec_with_params( } } else { assess_safety_for_untrusted_command( - sess.approval_policy, - &sess.sandbox_policy, + turn_context.approval_policy, + &turn_context.sandbox_policy, params.with_escalated_permissions.unwrap_or(false), ) }; @@ -1929,8 +1991,8 @@ async fn handle_container_exec_with_params( let state = sess.state.lock_unchecked(); assess_command_safety( ¶ms.command, - sess.approval_policy, - &sess.sandbox_policy, + turn_context.approval_policy, + &turn_context.sandbox_policy, &state.approved_commands, params.with_escalated_permissions.unwrap_or(false), ) @@ -2000,7 +2062,7 @@ async fn handle_container_exec_with_params( ), }; - let params = maybe_run_with_user_profile(params, sess); + let params = maybe_run_with_user_profile(params, sess, turn_context); let output_result = sess .run_exec_with_events( turn_diff_tracker, @@ -2008,7 +2070,7 @@ async fn handle_container_exec_with_params( ExecInvokeArgs { params: params.clone(), sandbox_type, - sandbox_policy: &sess.sandbox_policy, + sandbox_policy: &turn_context.sandbox_policy, codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe, stdout_stream: Some(StdoutStream { sub_id: sub_id.clone(), @@ -2041,6 +2103,7 @@ async fn handle_container_exec_with_params( error, sandbox_type, sess, + turn_context, ) .await } @@ -2061,6 +2124,7 @@ async fn handle_sandbox_error( error: SandboxErr, sandbox_type: SandboxType, sess: &Session, + turn_context: &TurnContext, ) -> ResponseInputItem { let call_id = exec_command_context.call_id.clone(); let sub_id = exec_command_context.sub_id.clone(); @@ -2068,7 +2132,7 @@ async fn handle_sandbox_error( // Early out if either the user never wants to be asked for approval, or // we're letting the model manage escalation requests. Otherwise, continue - match sess.approval_policy { + match turn_context.approval_policy { AskForApproval::Never | AskForApproval::OnRequest => { return ResponseInputItem::FunctionCallOutput { call_id, @@ -2139,7 +2203,7 @@ async fn handle_sandbox_error( ExecInvokeArgs { params, sandbox_type: SandboxType::None, - sandbox_policy: &sess.sandbox_policy, + sandbox_policy: &turn_context.sandbox_policy, codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe, stdout_stream: Some(StdoutStream { sub_id: sub_id.clone(), @@ -2253,8 +2317,13 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option CodexResult<()> { - let mut stream = sess.client.clone().stream(prompt).await?; +async fn drain_to_completed( + sess: &Session, + turn_context: &TurnContext, + sub_id: &str, + prompt: &Prompt, +) -> CodexResult<()> { + let mut stream = turn_context.client.clone().stream(prompt).await?; loop { let maybe_event = stream.next().await; let Some(event) = maybe_event else {