feat: introduce TurnContext (#2343)

This PR introduces `TurnContext`, which is designed to hold a set of
fields that should be constant for a turn of a conversation. Note that
the fields of `TurnContext` were previously governed by `Session`.

Ultimately, we want to enable users to change these values between turns
(changing model, approval policy, etc.), though in the current
implementation, the `TurnContext` is constant for the entire
conversation.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2345).
* #2345
* #2329
* __->__ #2343
* #2340
* #2338
This commit is contained in:
Michael Bolin
2025-08-15 09:40:02 -07:00
committed by GitHub
parent 45d6c74682
commit 13ed67cfc1
3 changed files with 175 additions and 104 deletions

View File

@@ -1,4 +1,5 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::models::FunctionCallOutputPayload;
use crate::models::ResponseInputItem;
use crate::protocol::FileChange;
@@ -40,15 +41,16 @@ impl From<ResponseInputItem> for InternalApplyPatchInvocation {
pub(crate) async fn apply_patch(
sess: &Session,
turn_context: &TurnContext,
sub_id: &str,
call_id: &str,
action: ApplyPatchAction,
) -> InternalApplyPatchInvocation {
match assess_patch_safety(
&action,
sess.get_approval_policy(),
sess.get_sandbox_policy(),
sess.get_cwd(),
turn_context.approval_policy,
&turn_context.sandbox_policy,
&turn_context.cwd,
) {
SafetyCheck::AutoApprove { .. } => {
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {

View File

@@ -56,7 +56,7 @@ struct Error {
message: Option<String>,
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct ModelClient {
config: Arc<Config>,
auth: Option<CodexAuth>,

View File

@@ -1,7 +1,6 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
@@ -166,16 +165,17 @@ impl Codex {
};
// Generate a unique ID for the lifetime of this Codex session.
let session = Session::new(configure_session, config.clone(), auth, tx_event.clone())
.await
.map_err(|e| {
error!("Failed to create session: {e:#}");
CodexErr::InternalAgentDied
})?;
let (session, turn_context) =
Session::new(configure_session, config.clone(), auth, tx_event.clone())
.await
.map_err(|e| {
error!("Failed to create session: {e:#}");
CodexErr::InternalAgentDied
})?;
let session_id = session.session_id;
// This task will run until Op::Shutdown is received.
tokio::spawn(submission_loop(session, config, rx_sub));
tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
let codex = Codex {
next_id: AtomicU64::new(0),
tx_sub,
@@ -231,21 +231,8 @@ struct State {
/// A session has at most 1 running task at a time, and can be interrupted by user input.
pub(crate) struct Session {
session_id: Uuid,
client: ModelClient,
tx_event: Sender<Event>,
/// The session's current working directory. All relative paths provided by
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
cwd: PathBuf,
base_instructions: Option<String>,
user_instructions: Option<String>,
approval_policy: AskForApproval,
sandbox_policy: SandboxPolicy,
shell_environment_policy: ShellEnvironmentPolicy,
disable_response_storage: bool,
tools_config: ToolsConfig,
/// Manager for external MCP servers/tools.
mcp_connection_manager: McpConnectionManager,
@@ -262,6 +249,31 @@ pub(crate) struct Session {
show_raw_agent_reasoning: bool,
}
/// The context needed for a single turn of the conversation.
#[derive(Debug)]
pub(crate) struct TurnContext {
pub(crate) client: ModelClient,
/// The session's current working directory. All relative paths provided by
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
pub(crate) cwd: PathBuf,
pub(crate) base_instructions: Option<String>,
pub(crate) user_instructions: Option<String>,
pub(crate) approval_policy: AskForApproval,
pub(crate) sandbox_policy: SandboxPolicy,
pub(crate) shell_environment_policy: ShellEnvironmentPolicy,
pub(crate) disable_response_storage: bool,
pub(crate) tools_config: ToolsConfig,
}
impl TurnContext {
fn resolve_path(&self, path: Option<String>) -> PathBuf {
path.as_ref()
.map(PathBuf::from)
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
}
}
/// Configure the model session.
struct ConfigureSession {
/// Provider identifier ("openai", "openrouter", ...).
@@ -309,7 +321,7 @@ impl Session {
config: Arc<Config>,
auth: Option<CodexAuth>,
tx_event: Sender<Event>,
) -> anyhow::Result<Arc<Self>> {
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
let ConfigureSession {
provider,
model,
@@ -457,8 +469,7 @@ impl Session {
model_reasoning_summary,
session_id,
);
let sess = Arc::new(Session {
session_id,
let turn_context = TurnContext {
client,
tools_config: ToolsConfig::new(
&config.model_family,
@@ -467,19 +478,22 @@ impl Session {
config.include_plan_tool,
config.include_apply_patch_tool,
),
tx_event: tx_event.clone(),
user_instructions,
base_instructions,
approval_policy,
sandbox_policy,
shell_environment_policy: config.shell_environment_policy.clone(),
cwd,
disable_response_storage,
};
let sess = Arc::new(Session {
session_id,
tx_event: tx_event.clone(),
mcp_connection_manager,
notify,
state: Mutex::new(state),
rollout: Mutex::new(rollout_recorder),
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
disable_response_storage,
user_shell: default_shell,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
});
@@ -487,13 +501,13 @@ impl Session {
// record the initial user instructions and environment context,
// regardless of whether we restored items.
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
if let Some(user_instructions) = sess.user_instructions.as_deref() {
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
}
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
sess.get_cwd().to_path_buf(),
sess.get_approval_policy(),
sess.sandbox_policy.clone(),
turn_context.cwd.to_path_buf(),
turn_context.approval_policy,
turn_context.sandbox_policy.clone(),
)));
sess.record_conversation_items(&conversation_items).await;
@@ -514,25 +528,7 @@ impl Session {
}
}
Ok(sess)
}
pub(crate) fn get_approval_policy(&self) -> AskForApproval {
self.approval_policy
}
pub(crate) fn get_sandbox_policy(&self) -> &SandboxPolicy {
&self.sandbox_policy
}
pub(crate) fn get_cwd(&self) -> &Path {
&self.cwd
}
fn resolve_path(&self, path: Option<String>) -> PathBuf {
path.as_ref()
.map(PathBuf::from)
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
Ok((sess, turn_context))
}
pub fn set_task(&self, task: AgentTask) {
@@ -921,9 +917,19 @@ pub(crate) struct AgentTask {
}
impl AgentTask {
fn spawn(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) -> Self {
let handle =
tokio::spawn(run_task(Arc::clone(&sess), sub_id.clone(), input)).abort_handle();
fn spawn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc.as_ref(), sub_id, input).await })
.abort_handle()
};
Self {
sess,
sub_id,
@@ -933,17 +939,20 @@ impl AgentTask {
fn compact(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
compact_instructions: String,
) -> Self {
let handle = tokio::spawn(run_compact_task(
Arc::clone(&sess),
sub_id.clone(),
input,
compact_instructions,
))
.abort_handle();
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move {
run_compact_task(sess, tc.as_ref(), sub_id, input, compact_instructions).await
})
.abort_handle()
};
Self {
sess,
sub_id,
@@ -968,7 +977,14 @@ impl AgentTask {
}
}
async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiver<Submission>) {
async fn submission_loop(
sess: Arc<Session>,
turn_context: TurnContext,
config: Arc<Config>,
rx_sub: Receiver<Submission>,
) {
// Wrap once to avoid cloning TurnContext for each task.
let turn_context = Arc::new(turn_context);
// To break out of this loop, send Op::Shutdown.
while let Ok(sub) = rx_sub.recv().await {
debug!(?sub, "Submission");
@@ -980,7 +996,8 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
// attempt to inject input into current task
if let Err(items) = sess.inject_input(items) {
// no current task, spawn a new one
let task = AgentTask::spawn(sess.clone(), sub.id, items);
let task =
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task);
}
}
@@ -1046,6 +1063,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
}]) {
let task = AgentTask::compact(
sess.clone(),
Arc::clone(&turn_context),
sub.id,
items,
SUMMARIZATION_PROMPT.to_string(),
@@ -1101,7 +1119,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
/// back to the model in the next turn.
/// - If the model sends only an assistant message, we record it in the
/// conversation history and consider the task complete.
async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
async fn run_task(
sess: Arc<Session>,
turn_context: &TurnContext,
sub_id: String,
input: Vec<InputItem>,
) {
if input.is_empty() {
return;
}
@@ -1153,7 +1176,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
})
})
.collect();
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
match run_turn(
&sess,
turn_context,
&mut turn_diff_tracker,
sub_id.clone(),
turn_input,
)
.await
{
Ok(turn_output) => {
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
@@ -1282,25 +1313,26 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
async fn run_turn(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
let tools = get_openai_tools(
&sess.tools_config,
&turn_context.tools_config,
Some(sess.mcp_connection_manager.list_all_tools()),
);
let prompt = Prompt {
input,
store: !sess.disable_response_storage,
store: !turn_context.disable_response_storage,
tools,
base_instructions_override: sess.base_instructions.clone(),
base_instructions_override: turn_context.base_instructions.clone(),
};
let mut retries = 0;
loop {
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
match try_run_turn(sess, turn_context, turn_diff_tracker, &sub_id, &prompt).await {
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
@@ -1309,7 +1341,7 @@ async fn run_turn(
}
Err(e) => {
// Use the configured provider-specific stream retry budget.
let max_retries = sess.client.get_provider().stream_max_retries();
let max_retries = turn_context.client.get_provider().stream_max_retries();
if retries < max_retries {
retries += 1;
let delay = match e {
@@ -1352,6 +1384,7 @@ struct ProcessedResponseItem {
async fn try_run_turn(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
@@ -1412,7 +1445,7 @@ async fn try_run_turn(
})
};
let mut stream = sess.client.clone().stream(&prompt).await?;
let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut output = Vec::new();
loop {
@@ -1441,9 +1474,14 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response =
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
let response = handle_response_item(
sess,
turn_context,
turn_diff_tracker,
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
}
ResponseEvent::Completed {
@@ -1515,6 +1553,7 @@ async fn try_run_turn(
async fn run_compact_task(
sess: Arc<Session>,
turn_context: &TurnContext,
sub_id: String,
input: Vec<InputItem>,
compact_instructions: String,
@@ -1533,16 +1572,16 @@ async fn run_compact_task(
let prompt = Prompt {
input: turn_input,
store: !sess.disable_response_storage,
store: !turn_context.disable_response_storage,
tools: Vec::new(),
base_instructions_override: Some(compact_instructions.clone()),
};
let max_retries = sess.client.get_provider().stream_max_retries();
let max_retries = turn_context.client.get_provider().stream_max_retries();
let mut retries = 0;
loop {
let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await;
let attempt_result = drain_to_completed(&sess, turn_context, &sub_id, &prompt).await;
match attempt_result {
Ok(()) => break,
@@ -1596,6 +1635,7 @@ async fn run_compact_task(
async fn handle_response_item(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
item: ResponseItem,
@@ -1659,6 +1699,7 @@ async fn handle_response_item(
Some(
handle_function_call(
sess,
turn_context,
turn_diff_tracker,
sub_id.to_string(),
name,
@@ -1698,11 +1739,12 @@ async fn handle_response_item(
}
};
let exec_params = to_exec_params(params, sess);
let exec_params = to_exec_params(params, turn_context);
Some(
handle_container_exec_with_params(
exec_params,
sess,
turn_context,
turn_diff_tracker,
sub_id.to_string(),
effective_call_id,
@@ -1721,6 +1763,7 @@ async fn handle_response_item(
async fn handle_function_call(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
name: String,
@@ -1729,14 +1772,21 @@ async fn handle_function_call(
) -> ResponseInputItem {
match name.as_str() {
"container.exec" | "shell" => {
let params = match parse_container_exec_arguments(arguments, sess, &call_id) {
let params = match parse_container_exec_arguments(arguments, turn_context, &call_id) {
Ok(params) => params,
Err(output) => {
return *output;
}
};
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
.await
handle_container_exec_with_params(
params,
sess,
turn_context,
turn_diff_tracker,
sub_id,
call_id,
)
.await
}
"apply_patch" => {
let args = match serde_json::from_str::<ApplyPatchToolArgs>(&arguments) {
@@ -1753,14 +1803,21 @@ async fn handle_function_call(
};
let exec_params = ExecParams {
command: vec!["apply_patch".to_string(), args.input.clone()],
cwd: sess.cwd.clone(),
cwd: turn_context.cwd.clone(),
timeout_ms: None,
env: HashMap::new(),
with_escalated_permissions: None,
justification: None,
};
handle_container_exec_with_params(exec_params, sess, turn_diff_tracker, sub_id, call_id)
.await
handle_container_exec_with_params(
exec_params,
sess,
turn_context,
turn_diff_tracker,
sub_id,
call_id,
)
.await
}
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
_ => {
@@ -1788,12 +1845,12 @@ async fn handle_function_call(
}
}
fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
ExecParams {
command: params.command,
cwd: sess.resolve_path(params.workdir.clone()),
cwd: turn_context.resolve_path(params.workdir.clone()),
timeout_ms: params.timeout_ms,
env: create_env(&sess.shell_environment_policy),
env: create_env(&turn_context.shell_environment_policy),
with_escalated_permissions: params.with_escalated_permissions,
justification: params.justification,
}
@@ -1801,12 +1858,12 @@ fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
fn parse_container_exec_arguments(
arguments: String,
sess: &Session,
turn_context: &TurnContext,
call_id: &str,
) -> Result<ExecParams, Box<ResponseInputItem>> {
// parse command
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)),
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, turn_context)),
Err(e) => {
// allow model to re-sample
let output = ResponseInputItem::FunctionCallOutput {
@@ -1829,8 +1886,12 @@ pub struct ExecInvokeArgs<'a> {
pub stdout_stream: Option<StdoutStream>,
}
fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams {
if sess.shell_environment_policy.use_profile {
fn maybe_run_with_user_profile(
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
) -> ExecParams {
if turn_context.shell_environment_policy.use_profile {
let command = sess
.user_shell
.format_default_shell_invocation(params.command.clone());
@@ -1844,6 +1905,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
async fn handle_container_exec_with_params(
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
call_id: String,
@@ -1851,7 +1913,7 @@ async fn handle_container_exec_with_params(
// check if this was a patch, and apply it if so
let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(sess, &sub_id, &call_id, changes).await {
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
InternalApplyPatchInvocation::Output(item) => return item,
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
Some(apply_patch_exec)
@@ -1913,8 +1975,8 @@ async fn handle_container_exec_with_params(
}
} else {
assess_safety_for_untrusted_command(
sess.approval_policy,
&sess.sandbox_policy,
turn_context.approval_policy,
&turn_context.sandbox_policy,
params.with_escalated_permissions.unwrap_or(false),
)
};
@@ -1929,8 +1991,8 @@ async fn handle_container_exec_with_params(
let state = sess.state.lock_unchecked();
assess_command_safety(
&params.command,
sess.approval_policy,
&sess.sandbox_policy,
turn_context.approval_policy,
&turn_context.sandbox_policy,
&state.approved_commands,
params.with_escalated_permissions.unwrap_or(false),
)
@@ -2000,7 +2062,7 @@ async fn handle_container_exec_with_params(
),
};
let params = maybe_run_with_user_profile(params, sess);
let params = maybe_run_with_user_profile(params, sess, turn_context);
let output_result = sess
.run_exec_with_events(
turn_diff_tracker,
@@ -2008,7 +2070,7 @@ async fn handle_container_exec_with_params(
ExecInvokeArgs {
params: params.clone(),
sandbox_type,
sandbox_policy: &sess.sandbox_policy,
sandbox_policy: &turn_context.sandbox_policy,
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
stdout_stream: Some(StdoutStream {
sub_id: sub_id.clone(),
@@ -2041,6 +2103,7 @@ async fn handle_container_exec_with_params(
error,
sandbox_type,
sess,
turn_context,
)
.await
}
@@ -2061,6 +2124,7 @@ async fn handle_sandbox_error(
error: SandboxErr,
sandbox_type: SandboxType,
sess: &Session,
turn_context: &TurnContext,
) -> ResponseInputItem {
let call_id = exec_command_context.call_id.clone();
let sub_id = exec_command_context.sub_id.clone();
@@ -2068,7 +2132,7 @@ async fn handle_sandbox_error(
// Early out if either the user never wants to be asked for approval, or
// we're letting the model manage escalation requests. Otherwise, continue
match sess.approval_policy {
match turn_context.approval_policy {
AskForApproval::Never | AskForApproval::OnRequest => {
return ResponseInputItem::FunctionCallOutput {
call_id,
@@ -2139,7 +2203,7 @@ async fn handle_sandbox_error(
ExecInvokeArgs {
params,
sandbox_type: SandboxType::None,
sandbox_policy: &sess.sandbox_policy,
sandbox_policy: &turn_context.sandbox_policy,
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
stdout_stream: Some(StdoutStream {
sub_id: sub_id.clone(),
@@ -2253,8 +2317,13 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
})
}
async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> CodexResult<()> {
let mut stream = sess.client.clone().stream(prompt).await?;
async fn drain_to_completed(
sess: &Session,
turn_context: &TurnContext,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {