feat: introduce TurnContext (#2343)
This PR introduces `TurnContext`, which is designed to hold a set of fields that should be constant for a turn of a conversation. Note that the fields of `TurnContext` were previously governed by `Session`. Ultimately, we want to enable users to change these values between turns (changing model, approval policy, etc.), though in the current implementation, the `TurnContext` is constant for the entire conversation. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2345). * #2345 * #2329 * __->__ #2343 * #2340 * #2338
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::models::FunctionCallOutputPayload;
|
use crate::models::FunctionCallOutputPayload;
|
||||||
use crate::models::ResponseInputItem;
|
use crate::models::ResponseInputItem;
|
||||||
use crate::protocol::FileChange;
|
use crate::protocol::FileChange;
|
||||||
@@ -40,15 +41,16 @@ impl From<ResponseInputItem> for InternalApplyPatchInvocation {
|
|||||||
|
|
||||||
pub(crate) async fn apply_patch(
|
pub(crate) async fn apply_patch(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
sub_id: &str,
|
sub_id: &str,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
action: ApplyPatchAction,
|
action: ApplyPatchAction,
|
||||||
) -> InternalApplyPatchInvocation {
|
) -> InternalApplyPatchInvocation {
|
||||||
match assess_patch_safety(
|
match assess_patch_safety(
|
||||||
&action,
|
&action,
|
||||||
sess.get_approval_policy(),
|
turn_context.approval_policy,
|
||||||
sess.get_sandbox_policy(),
|
&turn_context.sandbox_policy,
|
||||||
sess.get_cwd(),
|
&turn_context.cwd,
|
||||||
) {
|
) {
|
||||||
SafetyCheck::AutoApprove { .. } => {
|
SafetyCheck::AutoApprove { .. } => {
|
||||||
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
|
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ struct Error {
|
|||||||
message: Option<String>,
|
message: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ModelClient {
|
pub struct ModelClient {
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
auth: Option<CodexAuth>,
|
auth: Option<CodexAuth>,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::path::Path;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
@@ -166,16 +165,17 @@ impl Codex {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Generate a unique ID for the lifetime of this Codex session.
|
// Generate a unique ID for the lifetime of this Codex session.
|
||||||
let session = Session::new(configure_session, config.clone(), auth, tx_event.clone())
|
let (session, turn_context) =
|
||||||
.await
|
Session::new(configure_session, config.clone(), auth, tx_event.clone())
|
||||||
.map_err(|e| {
|
.await
|
||||||
error!("Failed to create session: {e:#}");
|
.map_err(|e| {
|
||||||
CodexErr::InternalAgentDied
|
error!("Failed to create session: {e:#}");
|
||||||
})?;
|
CodexErr::InternalAgentDied
|
||||||
|
})?;
|
||||||
let session_id = session.session_id;
|
let session_id = session.session_id;
|
||||||
|
|
||||||
// This task will run until Op::Shutdown is received.
|
// This task will run until Op::Shutdown is received.
|
||||||
tokio::spawn(submission_loop(session, config, rx_sub));
|
tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
|
||||||
let codex = Codex {
|
let codex = Codex {
|
||||||
next_id: AtomicU64::new(0),
|
next_id: AtomicU64::new(0),
|
||||||
tx_sub,
|
tx_sub,
|
||||||
@@ -231,21 +231,8 @@ struct State {
|
|||||||
/// A session has at most 1 running task at a time, and can be interrupted by user input.
|
/// A session has at most 1 running task at a time, and can be interrupted by user input.
|
||||||
pub(crate) struct Session {
|
pub(crate) struct Session {
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
client: ModelClient,
|
|
||||||
tx_event: Sender<Event>,
|
tx_event: Sender<Event>,
|
||||||
|
|
||||||
/// The session's current working directory. All relative paths provided by
|
|
||||||
/// the model as well as sandbox policies are resolved against this path
|
|
||||||
/// instead of `std::env::current_dir()`.
|
|
||||||
cwd: PathBuf,
|
|
||||||
base_instructions: Option<String>,
|
|
||||||
user_instructions: Option<String>,
|
|
||||||
approval_policy: AskForApproval,
|
|
||||||
sandbox_policy: SandboxPolicy,
|
|
||||||
shell_environment_policy: ShellEnvironmentPolicy,
|
|
||||||
disable_response_storage: bool,
|
|
||||||
tools_config: ToolsConfig,
|
|
||||||
|
|
||||||
/// Manager for external MCP servers/tools.
|
/// Manager for external MCP servers/tools.
|
||||||
mcp_connection_manager: McpConnectionManager,
|
mcp_connection_manager: McpConnectionManager,
|
||||||
|
|
||||||
@@ -262,6 +249,31 @@ pub(crate) struct Session {
|
|||||||
show_raw_agent_reasoning: bool,
|
show_raw_agent_reasoning: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The context needed for a single turn of the conversation.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct TurnContext {
|
||||||
|
pub(crate) client: ModelClient,
|
||||||
|
/// The session's current working directory. All relative paths provided by
|
||||||
|
/// the model as well as sandbox policies are resolved against this path
|
||||||
|
/// instead of `std::env::current_dir()`.
|
||||||
|
pub(crate) cwd: PathBuf,
|
||||||
|
pub(crate) base_instructions: Option<String>,
|
||||||
|
pub(crate) user_instructions: Option<String>,
|
||||||
|
pub(crate) approval_policy: AskForApproval,
|
||||||
|
pub(crate) sandbox_policy: SandboxPolicy,
|
||||||
|
pub(crate) shell_environment_policy: ShellEnvironmentPolicy,
|
||||||
|
pub(crate) disable_response_storage: bool,
|
||||||
|
pub(crate) tools_config: ToolsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TurnContext {
|
||||||
|
fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
||||||
|
path.as_ref()
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Configure the model session.
|
/// Configure the model session.
|
||||||
struct ConfigureSession {
|
struct ConfigureSession {
|
||||||
/// Provider identifier ("openai", "openrouter", ...).
|
/// Provider identifier ("openai", "openrouter", ...).
|
||||||
@@ -309,7 +321,7 @@ impl Session {
|
|||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
auth: Option<CodexAuth>,
|
auth: Option<CodexAuth>,
|
||||||
tx_event: Sender<Event>,
|
tx_event: Sender<Event>,
|
||||||
) -> anyhow::Result<Arc<Self>> {
|
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
|
||||||
let ConfigureSession {
|
let ConfigureSession {
|
||||||
provider,
|
provider,
|
||||||
model,
|
model,
|
||||||
@@ -457,8 +469,7 @@ impl Session {
|
|||||||
model_reasoning_summary,
|
model_reasoning_summary,
|
||||||
session_id,
|
session_id,
|
||||||
);
|
);
|
||||||
let sess = Arc::new(Session {
|
let turn_context = TurnContext {
|
||||||
session_id,
|
|
||||||
client,
|
client,
|
||||||
tools_config: ToolsConfig::new(
|
tools_config: ToolsConfig::new(
|
||||||
&config.model_family,
|
&config.model_family,
|
||||||
@@ -467,19 +478,22 @@ impl Session {
|
|||||||
config.include_plan_tool,
|
config.include_plan_tool,
|
||||||
config.include_apply_patch_tool,
|
config.include_apply_patch_tool,
|
||||||
),
|
),
|
||||||
tx_event: tx_event.clone(),
|
|
||||||
user_instructions,
|
user_instructions,
|
||||||
base_instructions,
|
base_instructions,
|
||||||
approval_policy,
|
approval_policy,
|
||||||
sandbox_policy,
|
sandbox_policy,
|
||||||
shell_environment_policy: config.shell_environment_policy.clone(),
|
shell_environment_policy: config.shell_environment_policy.clone(),
|
||||||
cwd,
|
cwd,
|
||||||
|
disable_response_storage,
|
||||||
|
};
|
||||||
|
let sess = Arc::new(Session {
|
||||||
|
session_id,
|
||||||
|
tx_event: tx_event.clone(),
|
||||||
mcp_connection_manager,
|
mcp_connection_manager,
|
||||||
notify,
|
notify,
|
||||||
state: Mutex::new(state),
|
state: Mutex::new(state),
|
||||||
rollout: Mutex::new(rollout_recorder),
|
rollout: Mutex::new(rollout_recorder),
|
||||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||||
disable_response_storage,
|
|
||||||
user_shell: default_shell,
|
user_shell: default_shell,
|
||||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||||
});
|
});
|
||||||
@@ -487,13 +501,13 @@ impl Session {
|
|||||||
// record the initial user instructions and environment context,
|
// record the initial user instructions and environment context,
|
||||||
// regardless of whether we restored items.
|
// regardless of whether we restored items.
|
||||||
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
||||||
if let Some(user_instructions) = sess.user_instructions.as_deref() {
|
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||||
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
||||||
}
|
}
|
||||||
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
||||||
sess.get_cwd().to_path_buf(),
|
turn_context.cwd.to_path_buf(),
|
||||||
sess.get_approval_policy(),
|
turn_context.approval_policy,
|
||||||
sess.sandbox_policy.clone(),
|
turn_context.sandbox_policy.clone(),
|
||||||
)));
|
)));
|
||||||
sess.record_conversation_items(&conversation_items).await;
|
sess.record_conversation_items(&conversation_items).await;
|
||||||
|
|
||||||
@@ -514,25 +528,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(sess)
|
Ok((sess, turn_context))
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn get_approval_policy(&self) -> AskForApproval {
|
|
||||||
self.approval_policy
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn get_sandbox_policy(&self) -> &SandboxPolicy {
|
|
||||||
&self.sandbox_policy
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn get_cwd(&self) -> &Path {
|
|
||||||
&self.cwd
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
|
||||||
path.as_ref()
|
|
||||||
.map(PathBuf::from)
|
|
||||||
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_task(&self, task: AgentTask) {
|
pub fn set_task(&self, task: AgentTask) {
|
||||||
@@ -921,9 +917,19 @@ pub(crate) struct AgentTask {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AgentTask {
|
impl AgentTask {
|
||||||
fn spawn(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) -> Self {
|
fn spawn(
|
||||||
let handle =
|
sess: Arc<Session>,
|
||||||
tokio::spawn(run_task(Arc::clone(&sess), sub_id.clone(), input)).abort_handle();
|
turn_context: Arc<TurnContext>,
|
||||||
|
sub_id: String,
|
||||||
|
input: Vec<InputItem>,
|
||||||
|
) -> Self {
|
||||||
|
let handle = {
|
||||||
|
let sess = sess.clone();
|
||||||
|
let sub_id = sub_id.clone();
|
||||||
|
let tc = Arc::clone(&turn_context);
|
||||||
|
tokio::spawn(async move { run_task(sess, tc.as_ref(), sub_id, input).await })
|
||||||
|
.abort_handle()
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
sess,
|
sess,
|
||||||
sub_id,
|
sub_id,
|
||||||
@@ -933,17 +939,20 @@ impl AgentTask {
|
|||||||
|
|
||||||
fn compact(
|
fn compact(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
compact_instructions: String,
|
compact_instructions: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let handle = tokio::spawn(run_compact_task(
|
let handle = {
|
||||||
Arc::clone(&sess),
|
let sess = sess.clone();
|
||||||
sub_id.clone(),
|
let sub_id = sub_id.clone();
|
||||||
input,
|
let tc = Arc::clone(&turn_context);
|
||||||
compact_instructions,
|
tokio::spawn(async move {
|
||||||
))
|
run_compact_task(sess, tc.as_ref(), sub_id, input, compact_instructions).await
|
||||||
.abort_handle();
|
})
|
||||||
|
.abort_handle()
|
||||||
|
};
|
||||||
Self {
|
Self {
|
||||||
sess,
|
sess,
|
||||||
sub_id,
|
sub_id,
|
||||||
@@ -968,7 +977,14 @@ impl AgentTask {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiver<Submission>) {
|
async fn submission_loop(
|
||||||
|
sess: Arc<Session>,
|
||||||
|
turn_context: TurnContext,
|
||||||
|
config: Arc<Config>,
|
||||||
|
rx_sub: Receiver<Submission>,
|
||||||
|
) {
|
||||||
|
// Wrap once to avoid cloning TurnContext for each task.
|
||||||
|
let turn_context = Arc::new(turn_context);
|
||||||
// To break out of this loop, send Op::Shutdown.
|
// To break out of this loop, send Op::Shutdown.
|
||||||
while let Ok(sub) = rx_sub.recv().await {
|
while let Ok(sub) = rx_sub.recv().await {
|
||||||
debug!(?sub, "Submission");
|
debug!(?sub, "Submission");
|
||||||
@@ -980,7 +996,8 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
// attempt to inject input into current task
|
// attempt to inject input into current task
|
||||||
if let Err(items) = sess.inject_input(items) {
|
if let Err(items) = sess.inject_input(items) {
|
||||||
// no current task, spawn a new one
|
// no current task, spawn a new one
|
||||||
let task = AgentTask::spawn(sess.clone(), sub.id, items);
|
let task =
|
||||||
|
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||||
sess.set_task(task);
|
sess.set_task(task);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1046,6 +1063,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
}]) {
|
}]) {
|
||||||
let task = AgentTask::compact(
|
let task = AgentTask::compact(
|
||||||
sess.clone(),
|
sess.clone(),
|
||||||
|
Arc::clone(&turn_context),
|
||||||
sub.id,
|
sub.id,
|
||||||
items,
|
items,
|
||||||
SUMMARIZATION_PROMPT.to_string(),
|
SUMMARIZATION_PROMPT.to_string(),
|
||||||
@@ -1101,7 +1119,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
/// back to the model in the next turn.
|
/// back to the model in the next turn.
|
||||||
/// - If the model sends only an assistant message, we record it in the
|
/// - If the model sends only an assistant message, we record it in the
|
||||||
/// conversation history and consider the task complete.
|
/// conversation history and consider the task complete.
|
||||||
async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
async fn run_task(
|
||||||
|
sess: Arc<Session>,
|
||||||
|
turn_context: &TurnContext,
|
||||||
|
sub_id: String,
|
||||||
|
input: Vec<InputItem>,
|
||||||
|
) {
|
||||||
if input.is_empty() {
|
if input.is_empty() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -1153,7 +1176,15 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
|
match run_turn(
|
||||||
|
&sess,
|
||||||
|
turn_context,
|
||||||
|
&mut turn_diff_tracker,
|
||||||
|
sub_id.clone(),
|
||||||
|
turn_input,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(turn_output) => {
|
Ok(turn_output) => {
|
||||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||||
let mut responses = Vec::<ResponseInputItem>::new();
|
let mut responses = Vec::<ResponseInputItem>::new();
|
||||||
@@ -1282,25 +1313,26 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
|
|
||||||
async fn run_turn(
|
async fn run_turn(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
turn_diff_tracker: &mut TurnDiffTracker,
|
turn_diff_tracker: &mut TurnDiffTracker,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<ResponseItem>,
|
input: Vec<ResponseItem>,
|
||||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
&sess.tools_config,
|
&turn_context.tools_config,
|
||||||
Some(sess.mcp_connection_manager.list_all_tools()),
|
Some(sess.mcp_connection_manager.list_all_tools()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let prompt = Prompt {
|
let prompt = Prompt {
|
||||||
input,
|
input,
|
||||||
store: !sess.disable_response_storage,
|
store: !turn_context.disable_response_storage,
|
||||||
tools,
|
tools,
|
||||||
base_instructions_override: sess.base_instructions.clone(),
|
base_instructions_override: turn_context.base_instructions.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut retries = 0;
|
let mut retries = 0;
|
||||||
loop {
|
loop {
|
||||||
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
|
match try_run_turn(sess, turn_context, turn_diff_tracker, &sub_id, &prompt).await {
|
||||||
Ok(output) => return Ok(output),
|
Ok(output) => return Ok(output),
|
||||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||||
@@ -1309,7 +1341,7 @@ async fn run_turn(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Use the configured provider-specific stream retry budget.
|
// Use the configured provider-specific stream retry budget.
|
||||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||||
if retries < max_retries {
|
if retries < max_retries {
|
||||||
retries += 1;
|
retries += 1;
|
||||||
let delay = match e {
|
let delay = match e {
|
||||||
@@ -1352,6 +1384,7 @@ struct ProcessedResponseItem {
|
|||||||
|
|
||||||
async fn try_run_turn(
|
async fn try_run_turn(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
turn_diff_tracker: &mut TurnDiffTracker,
|
turn_diff_tracker: &mut TurnDiffTracker,
|
||||||
sub_id: &str,
|
sub_id: &str,
|
||||||
prompt: &Prompt,
|
prompt: &Prompt,
|
||||||
@@ -1412,7 +1445,7 @@ async fn try_run_turn(
|
|||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut stream = sess.client.clone().stream(&prompt).await?;
|
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||||
|
|
||||||
let mut output = Vec::new();
|
let mut output = Vec::new();
|
||||||
loop {
|
loop {
|
||||||
@@ -1441,9 +1474,14 @@ async fn try_run_turn(
|
|||||||
match event {
|
match event {
|
||||||
ResponseEvent::Created => {}
|
ResponseEvent::Created => {}
|
||||||
ResponseEvent::OutputItemDone(item) => {
|
ResponseEvent::OutputItemDone(item) => {
|
||||||
let response =
|
let response = handle_response_item(
|
||||||
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
|
sess,
|
||||||
|
turn_context,
|
||||||
|
turn_diff_tracker,
|
||||||
|
sub_id,
|
||||||
|
item.clone(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
output.push(ProcessedResponseItem { item, response });
|
output.push(ProcessedResponseItem { item, response });
|
||||||
}
|
}
|
||||||
ResponseEvent::Completed {
|
ResponseEvent::Completed {
|
||||||
@@ -1515,6 +1553,7 @@ async fn try_run_turn(
|
|||||||
|
|
||||||
async fn run_compact_task(
|
async fn run_compact_task(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
|
turn_context: &TurnContext,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
compact_instructions: String,
|
compact_instructions: String,
|
||||||
@@ -1533,16 +1572,16 @@ async fn run_compact_task(
|
|||||||
|
|
||||||
let prompt = Prompt {
|
let prompt = Prompt {
|
||||||
input: turn_input,
|
input: turn_input,
|
||||||
store: !sess.disable_response_storage,
|
store: !turn_context.disable_response_storage,
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
base_instructions_override: Some(compact_instructions.clone()),
|
base_instructions_override: Some(compact_instructions.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let max_retries = sess.client.get_provider().stream_max_retries();
|
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||||
let mut retries = 0;
|
let mut retries = 0;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let attempt_result = drain_to_completed(&sess, &sub_id, &prompt).await;
|
let attempt_result = drain_to_completed(&sess, turn_context, &sub_id, &prompt).await;
|
||||||
|
|
||||||
match attempt_result {
|
match attempt_result {
|
||||||
Ok(()) => break,
|
Ok(()) => break,
|
||||||
@@ -1596,6 +1635,7 @@ async fn run_compact_task(
|
|||||||
|
|
||||||
async fn handle_response_item(
|
async fn handle_response_item(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
turn_diff_tracker: &mut TurnDiffTracker,
|
turn_diff_tracker: &mut TurnDiffTracker,
|
||||||
sub_id: &str,
|
sub_id: &str,
|
||||||
item: ResponseItem,
|
item: ResponseItem,
|
||||||
@@ -1659,6 +1699,7 @@ async fn handle_response_item(
|
|||||||
Some(
|
Some(
|
||||||
handle_function_call(
|
handle_function_call(
|
||||||
sess,
|
sess,
|
||||||
|
turn_context,
|
||||||
turn_diff_tracker,
|
turn_diff_tracker,
|
||||||
sub_id.to_string(),
|
sub_id.to_string(),
|
||||||
name,
|
name,
|
||||||
@@ -1698,11 +1739,12 @@ async fn handle_response_item(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let exec_params = to_exec_params(params, sess);
|
let exec_params = to_exec_params(params, turn_context);
|
||||||
Some(
|
Some(
|
||||||
handle_container_exec_with_params(
|
handle_container_exec_with_params(
|
||||||
exec_params,
|
exec_params,
|
||||||
sess,
|
sess,
|
||||||
|
turn_context,
|
||||||
turn_diff_tracker,
|
turn_diff_tracker,
|
||||||
sub_id.to_string(),
|
sub_id.to_string(),
|
||||||
effective_call_id,
|
effective_call_id,
|
||||||
@@ -1721,6 +1763,7 @@ async fn handle_response_item(
|
|||||||
|
|
||||||
async fn handle_function_call(
|
async fn handle_function_call(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
turn_diff_tracker: &mut TurnDiffTracker,
|
turn_diff_tracker: &mut TurnDiffTracker,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
name: String,
|
name: String,
|
||||||
@@ -1729,14 +1772,21 @@ async fn handle_function_call(
|
|||||||
) -> ResponseInputItem {
|
) -> ResponseInputItem {
|
||||||
match name.as_str() {
|
match name.as_str() {
|
||||||
"container.exec" | "shell" => {
|
"container.exec" | "shell" => {
|
||||||
let params = match parse_container_exec_arguments(arguments, sess, &call_id) {
|
let params = match parse_container_exec_arguments(arguments, turn_context, &call_id) {
|
||||||
Ok(params) => params,
|
Ok(params) => params,
|
||||||
Err(output) => {
|
Err(output) => {
|
||||||
return *output;
|
return *output;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
|
handle_container_exec_with_params(
|
||||||
.await
|
params,
|
||||||
|
sess,
|
||||||
|
turn_context,
|
||||||
|
turn_diff_tracker,
|
||||||
|
sub_id,
|
||||||
|
call_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
"apply_patch" => {
|
"apply_patch" => {
|
||||||
let args = match serde_json::from_str::<ApplyPatchToolArgs>(&arguments) {
|
let args = match serde_json::from_str::<ApplyPatchToolArgs>(&arguments) {
|
||||||
@@ -1753,14 +1803,21 @@ async fn handle_function_call(
|
|||||||
};
|
};
|
||||||
let exec_params = ExecParams {
|
let exec_params = ExecParams {
|
||||||
command: vec!["apply_patch".to_string(), args.input.clone()],
|
command: vec!["apply_patch".to_string(), args.input.clone()],
|
||||||
cwd: sess.cwd.clone(),
|
cwd: turn_context.cwd.clone(),
|
||||||
timeout_ms: None,
|
timeout_ms: None,
|
||||||
env: HashMap::new(),
|
env: HashMap::new(),
|
||||||
with_escalated_permissions: None,
|
with_escalated_permissions: None,
|
||||||
justification: None,
|
justification: None,
|
||||||
};
|
};
|
||||||
handle_container_exec_with_params(exec_params, sess, turn_diff_tracker, sub_id, call_id)
|
handle_container_exec_with_params(
|
||||||
.await
|
exec_params,
|
||||||
|
sess,
|
||||||
|
turn_context,
|
||||||
|
turn_diff_tracker,
|
||||||
|
sub_id,
|
||||||
|
call_id,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
|
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
|
||||||
_ => {
|
_ => {
|
||||||
@@ -1788,12 +1845,12 @@ async fn handle_function_call(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
|
fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
|
||||||
ExecParams {
|
ExecParams {
|
||||||
command: params.command,
|
command: params.command,
|
||||||
cwd: sess.resolve_path(params.workdir.clone()),
|
cwd: turn_context.resolve_path(params.workdir.clone()),
|
||||||
timeout_ms: params.timeout_ms,
|
timeout_ms: params.timeout_ms,
|
||||||
env: create_env(&sess.shell_environment_policy),
|
env: create_env(&turn_context.shell_environment_policy),
|
||||||
with_escalated_permissions: params.with_escalated_permissions,
|
with_escalated_permissions: params.with_escalated_permissions,
|
||||||
justification: params.justification,
|
justification: params.justification,
|
||||||
}
|
}
|
||||||
@@ -1801,12 +1858,12 @@ fn to_exec_params(params: ShellToolCallParams, sess: &Session) -> ExecParams {
|
|||||||
|
|
||||||
fn parse_container_exec_arguments(
|
fn parse_container_exec_arguments(
|
||||||
arguments: String,
|
arguments: String,
|
||||||
sess: &Session,
|
turn_context: &TurnContext,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
) -> Result<ExecParams, Box<ResponseInputItem>> {
|
) -> Result<ExecParams, Box<ResponseInputItem>> {
|
||||||
// parse command
|
// parse command
|
||||||
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
|
match serde_json::from_str::<ShellToolCallParams>(&arguments) {
|
||||||
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, sess)),
|
Ok(shell_tool_call_params) => Ok(to_exec_params(shell_tool_call_params, turn_context)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// allow model to re-sample
|
// allow model to re-sample
|
||||||
let output = ResponseInputItem::FunctionCallOutput {
|
let output = ResponseInputItem::FunctionCallOutput {
|
||||||
@@ -1829,8 +1886,12 @@ pub struct ExecInvokeArgs<'a> {
|
|||||||
pub stdout_stream: Option<StdoutStream>,
|
pub stdout_stream: Option<StdoutStream>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams {
|
fn maybe_run_with_user_profile(
|
||||||
if sess.shell_environment_policy.use_profile {
|
params: ExecParams,
|
||||||
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
|
) -> ExecParams {
|
||||||
|
if turn_context.shell_environment_policy.use_profile {
|
||||||
let command = sess
|
let command = sess
|
||||||
.user_shell
|
.user_shell
|
||||||
.format_default_shell_invocation(params.command.clone());
|
.format_default_shell_invocation(params.command.clone());
|
||||||
@@ -1844,6 +1905,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
|
|||||||
async fn handle_container_exec_with_params(
|
async fn handle_container_exec_with_params(
|
||||||
params: ExecParams,
|
params: ExecParams,
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
turn_diff_tracker: &mut TurnDiffTracker,
|
turn_diff_tracker: &mut TurnDiffTracker,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
@@ -1851,7 +1913,7 @@ async fn handle_container_exec_with_params(
|
|||||||
// check if this was a patch, and apply it if so
|
// check if this was a patch, and apply it if so
|
||||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||||
MaybeApplyPatchVerified::Body(changes) => {
|
MaybeApplyPatchVerified::Body(changes) => {
|
||||||
match apply_patch::apply_patch(sess, &sub_id, &call_id, changes).await {
|
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
|
||||||
InternalApplyPatchInvocation::Output(item) => return item,
|
InternalApplyPatchInvocation::Output(item) => return item,
|
||||||
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
||||||
Some(apply_patch_exec)
|
Some(apply_patch_exec)
|
||||||
@@ -1913,8 +1975,8 @@ async fn handle_container_exec_with_params(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
assess_safety_for_untrusted_command(
|
assess_safety_for_untrusted_command(
|
||||||
sess.approval_policy,
|
turn_context.approval_policy,
|
||||||
&sess.sandbox_policy,
|
&turn_context.sandbox_policy,
|
||||||
params.with_escalated_permissions.unwrap_or(false),
|
params.with_escalated_permissions.unwrap_or(false),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
@@ -1929,8 +1991,8 @@ async fn handle_container_exec_with_params(
|
|||||||
let state = sess.state.lock_unchecked();
|
let state = sess.state.lock_unchecked();
|
||||||
assess_command_safety(
|
assess_command_safety(
|
||||||
¶ms.command,
|
¶ms.command,
|
||||||
sess.approval_policy,
|
turn_context.approval_policy,
|
||||||
&sess.sandbox_policy,
|
&turn_context.sandbox_policy,
|
||||||
&state.approved_commands,
|
&state.approved_commands,
|
||||||
params.with_escalated_permissions.unwrap_or(false),
|
params.with_escalated_permissions.unwrap_or(false),
|
||||||
)
|
)
|
||||||
@@ -2000,7 +2062,7 @@ async fn handle_container_exec_with_params(
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = maybe_run_with_user_profile(params, sess);
|
let params = maybe_run_with_user_profile(params, sess, turn_context);
|
||||||
let output_result = sess
|
let output_result = sess
|
||||||
.run_exec_with_events(
|
.run_exec_with_events(
|
||||||
turn_diff_tracker,
|
turn_diff_tracker,
|
||||||
@@ -2008,7 +2070,7 @@ async fn handle_container_exec_with_params(
|
|||||||
ExecInvokeArgs {
|
ExecInvokeArgs {
|
||||||
params: params.clone(),
|
params: params.clone(),
|
||||||
sandbox_type,
|
sandbox_type,
|
||||||
sandbox_policy: &sess.sandbox_policy,
|
sandbox_policy: &turn_context.sandbox_policy,
|
||||||
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
|
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
|
||||||
stdout_stream: Some(StdoutStream {
|
stdout_stream: Some(StdoutStream {
|
||||||
sub_id: sub_id.clone(),
|
sub_id: sub_id.clone(),
|
||||||
@@ -2041,6 +2103,7 @@ async fn handle_container_exec_with_params(
|
|||||||
error,
|
error,
|
||||||
sandbox_type,
|
sandbox_type,
|
||||||
sess,
|
sess,
|
||||||
|
turn_context,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@@ -2061,6 +2124,7 @@ async fn handle_sandbox_error(
|
|||||||
error: SandboxErr,
|
error: SandboxErr,
|
||||||
sandbox_type: SandboxType,
|
sandbox_type: SandboxType,
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
) -> ResponseInputItem {
|
) -> ResponseInputItem {
|
||||||
let call_id = exec_command_context.call_id.clone();
|
let call_id = exec_command_context.call_id.clone();
|
||||||
let sub_id = exec_command_context.sub_id.clone();
|
let sub_id = exec_command_context.sub_id.clone();
|
||||||
@@ -2068,7 +2132,7 @@ async fn handle_sandbox_error(
|
|||||||
|
|
||||||
// Early out if either the user never wants to be asked for approval, or
|
// Early out if either the user never wants to be asked for approval, or
|
||||||
// we're letting the model manage escalation requests. Otherwise, continue
|
// we're letting the model manage escalation requests. Otherwise, continue
|
||||||
match sess.approval_policy {
|
match turn_context.approval_policy {
|
||||||
AskForApproval::Never | AskForApproval::OnRequest => {
|
AskForApproval::Never | AskForApproval::OnRequest => {
|
||||||
return ResponseInputItem::FunctionCallOutput {
|
return ResponseInputItem::FunctionCallOutput {
|
||||||
call_id,
|
call_id,
|
||||||
@@ -2139,7 +2203,7 @@ async fn handle_sandbox_error(
|
|||||||
ExecInvokeArgs {
|
ExecInvokeArgs {
|
||||||
params,
|
params,
|
||||||
sandbox_type: SandboxType::None,
|
sandbox_type: SandboxType::None,
|
||||||
sandbox_policy: &sess.sandbox_policy,
|
sandbox_policy: &turn_context.sandbox_policy,
|
||||||
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
|
codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe,
|
||||||
stdout_stream: Some(StdoutStream {
|
stdout_stream: Some(StdoutStream {
|
||||||
sub_id: sub_id.clone(),
|
sub_id: sub_id.clone(),
|
||||||
@@ -2253,8 +2317,13 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> CodexResult<()> {
|
async fn drain_to_completed(
|
||||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
sess: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
|
sub_id: &str,
|
||||||
|
prompt: &Prompt,
|
||||||
|
) -> CodexResult<()> {
|
||||||
|
let mut stream = turn_context.client.clone().stream(prompt).await?;
|
||||||
loop {
|
loop {
|
||||||
let maybe_event = stream.next().await;
|
let maybe_event = stream.next().await;
|
||||||
let Some(event) = maybe_event else {
|
let Some(event) = maybe_event else {
|
||||||
|
|||||||
Reference in New Issue
Block a user