Pass TurnContext around instead of sub_id (#5421)
Today `sub_id` is an ID of a single incoming Codex Op submition. We then associate all events triggered by this operation using the same `sub_id`. At the same time we are also creating a TurnContext per submission and we'd like to start associating some events (item added/item completed) with an entire turn instead of just the operation that started it. Using turn context when sending events give us flexibility to change notification scheme.
This commit is contained in:
@@ -36,7 +36,6 @@ pub(crate) struct ApplyPatchExec {
|
||||
pub(crate) async fn apply_patch(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
action: ApplyPatchAction,
|
||||
) -> InternalApplyPatchInvocation {
|
||||
@@ -62,7 +61,7 @@ pub(crate) async fn apply_patch(
|
||||
// that similar patches can be auto-approved in the future during
|
||||
// this session.
|
||||
let rx_approve = sess
|
||||
.request_patch_approval(sub_id.to_owned(), call_id.to_owned(), &action, None, None)
|
||||
.request_patch_approval(turn_context, call_id.to_owned(), &action, None, None)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
|
||||
|
||||
@@ -255,6 +255,7 @@ pub(crate) struct Session {
|
||||
/// The context needed for a single turn of the conversation.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TurnContext {
|
||||
pub(crate) sub_id: String,
|
||||
pub(crate) client: ModelClient,
|
||||
/// The session's current working directory. All relative paths provided by
|
||||
/// the model as well as sandbox policies are resolved against this path
|
||||
@@ -359,6 +360,7 @@ impl Session {
|
||||
session_configuration: &SessionConfiguration,
|
||||
conversation_id: ConversationId,
|
||||
tx_event: Sender<Event>,
|
||||
sub_id: String,
|
||||
) -> TurnContext {
|
||||
let config = session_configuration.original_config_do_not_use.clone();
|
||||
let model_family = find_family_for_model(&session_configuration.model)
|
||||
@@ -392,7 +394,10 @@ impl Session {
|
||||
features: &config.features,
|
||||
});
|
||||
|
||||
let item_collector = ItemCollector::new(tx_event, conversation_id, sub_id.clone());
|
||||
|
||||
TurnContext {
|
||||
sub_id,
|
||||
client,
|
||||
cwd: session_configuration.cwd.clone(),
|
||||
base_instructions: session_configuration.base_instructions.clone(),
|
||||
@@ -404,7 +409,7 @@ impl Session {
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
item_collector: ItemCollector::new(tx_event, conversation_id, "turn_id".to_string()),
|
||||
item_collector,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,7 +592,7 @@ impl Session {
|
||||
})
|
||||
.chain(post_session_configured_error_events.into_iter());
|
||||
for event in events {
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
Ok(sess)
|
||||
@@ -638,6 +643,15 @@ impl Session {
|
||||
}
|
||||
|
||||
pub(crate) async fn new_turn(&self, updates: SessionSettingsUpdate) -> Arc<TurnContext> {
|
||||
let sub_id = self.next_internal_sub_id();
|
||||
self.new_turn_with_sub_id(sub_id, updates).await
|
||||
}
|
||||
|
||||
pub(crate) async fn new_turn_with_sub_id(
|
||||
&self,
|
||||
sub_id: String,
|
||||
updates: SessionSettingsUpdate,
|
||||
) -> Arc<TurnContext> {
|
||||
let session_configuration = {
|
||||
let mut state = self.state.lock().await;
|
||||
let session_configuration = state.session_configuration.clone().apply(&updates);
|
||||
@@ -652,6 +666,7 @@ impl Session {
|
||||
&session_configuration,
|
||||
self.conversation_id,
|
||||
self.get_tx_event(),
|
||||
sub_id,
|
||||
);
|
||||
if let Some(final_schema) = updates.final_output_json_schema {
|
||||
turn_context.final_output_json_schema = final_schema;
|
||||
@@ -678,7 +693,15 @@ impl Session {
|
||||
}
|
||||
|
||||
/// Persist the event to rollout and send it to clients.
|
||||
pub(crate) async fn send_event(&self, event: Event) {
|
||||
pub(crate) async fn send_event(&self, turn_context: &TurnContext, msg: EventMsg) {
|
||||
let event = Event {
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg,
|
||||
};
|
||||
self.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn send_event_raw(&self, event: Event) {
|
||||
// Persist the event into rollout (recorder filters as needed)
|
||||
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
||||
self.persist_rollout_items(&rollout_items).await;
|
||||
@@ -694,12 +717,13 @@ impl Session {
|
||||
/// default `ReviewDecision` (`Denied`).
|
||||
pub async fn request_command_approval(
|
||||
&self,
|
||||
sub_id: String,
|
||||
turn_context: &TurnContext,
|
||||
call_id: String,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
reason: Option<String>,
|
||||
) -> ReviewDecision {
|
||||
let sub_id = turn_context.sub_id.clone();
|
||||
// Add the tx_approve callback to the map before sending the request.
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
let event_id = sub_id.clone();
|
||||
@@ -718,28 +742,26 @@ impl Session {
|
||||
}
|
||||
|
||||
let parsed_cmd = parse_command(&command);
|
||||
let event = Event {
|
||||
id: event_id,
|
||||
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
parsed_cmd,
|
||||
}),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
let event = EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||
call_id,
|
||||
command,
|
||||
cwd,
|
||||
reason,
|
||||
parsed_cmd,
|
||||
});
|
||||
self.send_event(turn_context, event).await;
|
||||
rx_approve.await.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn request_patch_approval(
|
||||
&self,
|
||||
sub_id: String,
|
||||
turn_context: &TurnContext,
|
||||
call_id: String,
|
||||
action: &ApplyPatchAction,
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
) -> oneshot::Receiver<ReviewDecision> {
|
||||
let sub_id = turn_context.sub_id.clone();
|
||||
// Add the tx_approve callback to the map before sending the request.
|
||||
let (tx_approve, rx_approve) = oneshot::channel();
|
||||
let event_id = sub_id.clone();
|
||||
@@ -757,16 +779,13 @@ impl Session {
|
||||
warn!("Overwriting existing pending approval for sub_id: {event_id}");
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: event_id,
|
||||
msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
changes: convert_apply_patch_to_protocol(action),
|
||||
reason,
|
||||
grant_root,
|
||||
}),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
let event = EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
call_id,
|
||||
changes: convert_apply_patch_to_protocol(action),
|
||||
reason,
|
||||
grant_root,
|
||||
});
|
||||
self.send_event(turn_context, event).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
@@ -878,7 +897,6 @@ impl Session {
|
||||
|
||||
async fn update_token_usage_info(
|
||||
&self,
|
||||
sub_id: &str,
|
||||
turn_context: &TurnContext,
|
||||
token_usage: Option<&TokenUsage>,
|
||||
) {
|
||||
@@ -891,37 +909,38 @@ impl Session {
|
||||
);
|
||||
}
|
||||
}
|
||||
self.send_token_count_event(sub_id).await;
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) {
|
||||
async fn update_rate_limits(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
new_rate_limits: RateLimitSnapshot,
|
||||
) {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_rate_limits(new_rate_limits);
|
||||
}
|
||||
self.send_token_count_event(sub_id).await;
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
async fn send_token_count_event(&self, sub_id: &str) {
|
||||
async fn send_token_count_event(&self, turn_context: &TurnContext) {
|
||||
let (info, rate_limits) = {
|
||||
let state = self.state.lock().await;
|
||||
state.token_info_and_rate_limits()
|
||||
};
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::TokenCount(TokenCountEvent { info, rate_limits }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
let event = EventMsg::TokenCount(TokenCountEvent { info, rate_limits });
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) {
|
||||
async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
|
||||
let context_window = turn_context.client.get_model_context_window();
|
||||
if let Some(context_window) = context_window {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_token_usage_full(context_window);
|
||||
}
|
||||
self.send_token_count_event(sub_id).await;
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -951,24 +970,22 @@ impl Session {
|
||||
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
||||
/// the call‑sites terse so adding more diagnostics does not clutter the
|
||||
/// core agent logic.
|
||||
pub(crate) async fn notify_background_event(&self, sub_id: &str, message: impl Into<String>) {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::BackgroundEvent(BackgroundEventEvent {
|
||||
message: message.into(),
|
||||
}),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
pub(crate) async fn notify_background_event(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
message: impl Into<String>,
|
||||
) {
|
||||
let event = EventMsg::BackgroundEvent(BackgroundEventEvent {
|
||||
message: message.into(),
|
||||
});
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
async fn notify_stream_error(&self, sub_id: &str, message: impl Into<String>) {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::StreamError(StreamErrorEvent {
|
||||
message: message.into(),
|
||||
}),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
async fn notify_stream_error(&self, turn_context: &TurnContext, message: impl Into<String>) {
|
||||
let event = EventMsg::StreamError(StreamErrorEvent {
|
||||
message: message.into(),
|
||||
});
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
@@ -1129,7 +1146,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
Op::UserInput { items } => (items, SessionSettingsUpdate::default()),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let current_context = sess.new_turn(updates).await;
|
||||
let current_context = sess.new_turn_with_sub_id(sub.id.clone(), updates).await;
|
||||
current_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
@@ -1145,11 +1162,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
&env_item,
|
||||
sess.show_raw_agent_reasoning(),
|
||||
) {
|
||||
let event = Event {
|
||||
id: sub.id.clone(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event(¤t_context, msg).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1158,7 +1171,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
.started_completed(TurnItem::UserMessage(UserMessageItem::new(&items)))
|
||||
.await;
|
||||
|
||||
sess.spawn_task(Arc::clone(¤t_context), sub.id, items, RegularTask)
|
||||
sess.spawn_task(Arc::clone(¤t_context), items, RegularTask)
|
||||
.await;
|
||||
previous_context = Some(current_context);
|
||||
}
|
||||
@@ -1216,7 +1229,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
),
|
||||
};
|
||||
|
||||
sess_clone.send_event(event).await;
|
||||
sess_clone.send_event_raw(event).await;
|
||||
});
|
||||
}
|
||||
Op::ListMcpTools => {
|
||||
@@ -1249,7 +1262,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
},
|
||||
),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
Op::ListCustomPrompts => {
|
||||
let sub_id = sub.id.clone();
|
||||
@@ -1267,10 +1280,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
custom_prompts,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
Op::Compact => {
|
||||
let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await;
|
||||
let turn_context = sess
|
||||
.new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default())
|
||||
.await;
|
||||
// Attempt to inject input into current task
|
||||
if let Err(items) = sess
|
||||
.inject_input(vec![UserInput::Text {
|
||||
@@ -1278,7 +1293,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
}])
|
||||
.await
|
||||
{
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
|
||||
sess.spawn_task(Arc::clone(&turn_context), items, CompactTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -1302,14 +1317,14 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
message: "Failed to shutdown rollout recorder".to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub.id.clone(),
|
||||
msg: EventMsg::ShutdownComplete,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
break;
|
||||
}
|
||||
Op::GetPath => {
|
||||
@@ -1337,10 +1352,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
||||
path,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event_raw(event).await;
|
||||
}
|
||||
Op::Review { review_request } => {
|
||||
let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await;
|
||||
let turn_context = sess
|
||||
.new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default())
|
||||
.await;
|
||||
spawn_review_thread(
|
||||
sess.clone(),
|
||||
config.clone(),
|
||||
@@ -1416,6 +1433,7 @@ async fn spawn_review_thread(
|
||||
);
|
||||
|
||||
let review_turn_context = TurnContext {
|
||||
sub_id: sub_id.to_string(),
|
||||
client,
|
||||
tools_config,
|
||||
user_instructions: None,
|
||||
@@ -1439,17 +1457,11 @@ async fn spawn_review_thread(
|
||||
text: review_prompt,
|
||||
}];
|
||||
let tc = Arc::new(review_turn_context);
|
||||
|
||||
// Clone sub_id for the upcoming announcement before moving it into the task.
|
||||
let sub_id_for_event = sub_id.clone();
|
||||
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
|
||||
sess.spawn_task(tc.clone(), input, ReviewTask).await;
|
||||
|
||||
// Announce entering review mode so UIs can switch modes.
|
||||
sess.send_event(Event {
|
||||
id: sub_id_for_event,
|
||||
msg: EventMsg::EnteredReviewMode(review_request),
|
||||
})
|
||||
.await;
|
||||
sess.send_event(&tc, EventMsg::EnteredReviewMode(review_request))
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Takes a user message as input and runs a loop where, at each turn, the model
|
||||
@@ -1472,7 +1484,6 @@ async fn spawn_review_thread(
|
||||
pub(crate) async fn run_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -1480,13 +1491,10 @@ pub(crate) async fn run_task(
|
||||
if input.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
// For review threads, keep an isolated in-memory history so the
|
||||
@@ -1557,7 +1565,6 @@ pub(crate) async fn run_task(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
@@ -1689,15 +1696,12 @@ pub(crate) async fn run_task(
|
||||
let current_tokens = total_usage_tokens
|
||||
.map(|tokens| tokens.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input."
|
||||
),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::Error(ErrorEvent {
|
||||
message: format!(
|
||||
"Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input."
|
||||
),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
break;
|
||||
}
|
||||
auto_compact_recently_attempted = true;
|
||||
@@ -1714,7 +1718,7 @@ pub(crate) async fn run_task(
|
||||
sess.notifier()
|
||||
.notify(&UserNotification::AgentTurnComplete {
|
||||
thread_id: sess.conversation_id.to_string(),
|
||||
turn_id: sub_id.clone(),
|
||||
turn_id: turn_context.sub_id.clone(),
|
||||
cwd: turn_context.cwd.display().to_string(),
|
||||
input_messages: turn_input_messages,
|
||||
last_assistant_message: last_agent_message.clone(),
|
||||
@@ -1729,13 +1733,10 @@ pub(crate) async fn run_task(
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
// let the user continue the conversation
|
||||
break;
|
||||
}
|
||||
@@ -1752,7 +1753,7 @@ pub(crate) async fn run_task(
|
||||
if turn_context.is_review_mode {
|
||||
exit_review_mode(
|
||||
sess.clone(),
|
||||
sub_id.clone(),
|
||||
Arc::clone(&turn_context),
|
||||
last_agent_message.as_deref().map(parse_review_output_event),
|
||||
)
|
||||
.await;
|
||||
@@ -1790,7 +1791,6 @@ async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -1821,7 +1821,6 @@ async fn run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&sub_id,
|
||||
&prompt,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
@@ -1834,13 +1833,14 @@ async fn run_turn(
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&sub_id, &turn_context).await;
|
||||
sess.set_total_tokens_full(turn_context.as_ref()).await;
|
||||
return Err(e);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
sess.update_rate_limits(&sub_id, rate_limits).await;
|
||||
sess.update_rate_limits(turn_context.as_ref(), rate_limits)
|
||||
.await;
|
||||
}
|
||||
return Err(CodexErr::UsageLimitReached(e));
|
||||
}
|
||||
@@ -1862,7 +1862,7 @@ async fn run_turn(
|
||||
// user understands what is happening instead of staring
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
turn_context.as_ref(),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
@@ -1898,7 +1898,6 @@ async fn try_run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -1979,7 +1978,6 @@ async fn try_run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.to_string(),
|
||||
);
|
||||
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
|
||||
FuturesOrdered::new();
|
||||
@@ -2028,7 +2026,6 @@ async fn try_run_turn(
|
||||
let response = handle_non_tool_response_item(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -2077,7 +2074,7 @@ async fn try_run_turn(
|
||||
let _ = sess
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
id: turn_context.sub_id.clone(),
|
||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||
})
|
||||
.await;
|
||||
@@ -2085,13 +2082,14 @@ async fn try_run_turn(
|
||||
ResponseEvent::RateLimits(snapshot) => {
|
||||
// Update internal state with latest rate limits, but defer sending until
|
||||
// token usage is available to avoid duplicate TokenCount events.
|
||||
sess.update_rate_limits(sub_id, snapshot).await;
|
||||
sess.update_rate_limits(turn_context.as_ref(), snapshot)
|
||||
.await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
token_usage,
|
||||
} => {
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let processed_items = output
|
||||
@@ -2105,11 +2103,7 @@ async fn try_run_turn(
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event(&turn_context, msg).await;
|
||||
}
|
||||
|
||||
let result = TurnRunResult {
|
||||
@@ -2123,38 +2117,27 @@ async fn try_run_turn(
|
||||
// In review child threads, suppress assistant text deltas; the
|
||||
// UI will show a selection popup from the final ReviewOutput.
|
||||
if !turn_context.is_review_mode {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta });
|
||||
sess.send_event(&turn_context, event).await;
|
||||
} else {
|
||||
trace!("suppressing OutputTextDelta in review mode");
|
||||
}
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryDelta(delta) => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta });
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryPartAdded => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event =
|
||||
EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningContentDelta(delta) => {
|
||||
if sess.show_raw_agent_reasoning() {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningRawContentDelta(
|
||||
AgentReasoningRawContentDeltaEvent { delta },
|
||||
),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::AgentReasoningRawContentDelta(
|
||||
AgentReasoningRawContentDeltaEvent { delta },
|
||||
);
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2164,7 +2147,6 @@ async fn try_run_turn(
|
||||
async fn handle_non_tool_response_item(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
debug!(?item, "Output item");
|
||||
@@ -2181,11 +2163,7 @@ async fn handle_non_tool_response_item(
|
||||
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.send_event(&turn_context, msg).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
||||
@@ -2255,16 +2233,13 @@ fn convert_call_tool_result_to_function_call_output_payload(
|
||||
/// and records a developer message with the review output.
|
||||
pub(crate) async fn exit_review_mode(
|
||||
session: Arc<Session>,
|
||||
task_sub_id: String,
|
||||
turn_context: Arc<TurnContext>,
|
||||
review_output: Option<ReviewOutputEvent>,
|
||||
) {
|
||||
let event = Event {
|
||||
id: task_sub_id,
|
||||
msg: EventMsg::ExitedReviewMode(ExitedReviewModeEvent {
|
||||
review_output: review_output.clone(),
|
||||
}),
|
||||
};
|
||||
session.send_event(event).await;
|
||||
let event = EventMsg::ExitedReviewMode(ExitedReviewModeEvent {
|
||||
review_output: review_output.clone(),
|
||||
});
|
||||
session.send_event(turn_context.as_ref(), event).await;
|
||||
|
||||
let mut user_message = String::new();
|
||||
if let Some(out) = review_output {
|
||||
@@ -2676,6 +2651,7 @@ mod tests {
|
||||
&session_configuration,
|
||||
conversation_id,
|
||||
tx_event.clone(),
|
||||
"turn_id".to_string(),
|
||||
);
|
||||
|
||||
let session = Session {
|
||||
@@ -2744,6 +2720,7 @@ mod tests {
|
||||
&session_configuration,
|
||||
conversation_id,
|
||||
tx_event.clone(),
|
||||
"turn_id".to_string(),
|
||||
));
|
||||
|
||||
let session = Arc::new(Session {
|
||||
@@ -2774,7 +2751,6 @@ mod tests {
|
||||
self: Arc<Self>,
|
||||
_session: Arc<SessionTaskContext>,
|
||||
_ctx: Arc<TurnContext>,
|
||||
_sub_id: String,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
@@ -2787,9 +2763,9 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||
if let TaskKind::Review = self.kind {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
exit_review_mode(session.clone_session(), ctx, None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2798,13 +2774,11 @@ mod tests {
|
||||
#[test_log::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![UserInput::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
@@ -2829,13 +2803,11 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn abort_gracefuly_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![UserInput::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
@@ -2857,13 +2829,11 @@ mod tests {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-review".to_string();
|
||||
let input = vec![UserInput::Text {
|
||||
text: "start review".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Review,
|
||||
@@ -2935,7 +2905,6 @@ mod tests {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
tracker,
|
||||
"sub-id".to_string(),
|
||||
call,
|
||||
)
|
||||
.await
|
||||
@@ -3095,7 +3064,6 @@ mod tests {
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let tool_name = "shell";
|
||||
let sub_id = "test-sub".to_string();
|
||||
let call_id = "test-call".to_string();
|
||||
|
||||
let resp = handle_container_exec_with_params(
|
||||
@@ -3104,7 +3072,6 @@ mod tests {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id,
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
@@ -3132,7 +3099,6 @@ mod tests {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
"test-sub".to_string(),
|
||||
"test-call-2".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -10,7 +10,6 @@ use crate::error::Result as CodexResult;
|
||||
use crate::protocol::AgentMessageEvent;
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::ErrorEvent;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
@@ -40,34 +39,28 @@ pub(crate) async fn run_inline_auto_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) {
|
||||
let sub_id = sess.next_internal_sub_id();
|
||||
let input = vec![UserInput::Text {
|
||||
text: SUMMARIZATION_PROMPT.to_string(),
|
||||
}];
|
||||
run_compact_task_inner(sess, turn_context, sub_id, input).await;
|
||||
run_compact_task_inner(sess, turn_context, input).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
) -> Option<String> {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await;
|
||||
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
});
|
||||
sess.send_event(&turn_context, start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, input).await;
|
||||
None
|
||||
}
|
||||
|
||||
async fn run_compact_task_inner(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
@@ -94,14 +87,13 @@ async fn run_compact_task_inner(
|
||||
input: turn_input.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
let attempt_result =
|
||||
drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await;
|
||||
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
||||
|
||||
match attempt_result {
|
||||
Ok(()) => {
|
||||
if truncated_count > 0 {
|
||||
sess.notify_background_event(
|
||||
&sub_id,
|
||||
turn_context.as_ref(),
|
||||
format!(
|
||||
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
|
||||
),
|
||||
@@ -120,15 +112,11 @@ async fn run_compact_task_inner(
|
||||
retries = 0;
|
||||
continue;
|
||||
}
|
||||
sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
|
||||
.await;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
sess.set_total_tokens_full(turn_context.as_ref()).await;
|
||||
let event = EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -136,20 +124,17 @@ async fn run_compact_task_inner(
|
||||
retries += 1;
|
||||
let delay = backoff(retries);
|
||||
sess.notify_stream_error(
|
||||
&sub_id,
|
||||
turn_context.as_ref(),
|
||||
format!("Re-connecting... {retries}/{max_retries}"),
|
||||
)
|
||||
.await;
|
||||
tokio::time::sleep(delay).await;
|
||||
continue;
|
||||
} else {
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::Error(ErrorEvent {
|
||||
message: e.to_string(),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -168,13 +153,10 @@ async fn run_compact_task_inner(
|
||||
});
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
let event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||
message: "Compact task completed".to_string(),
|
||||
});
|
||||
sess.send_event(&turn_context, event).await;
|
||||
}
|
||||
|
||||
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
|
||||
@@ -256,7 +238,6 @@ pub(crate) fn build_compacted_history(
|
||||
async fn drain_to_completed(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut stream = turn_context
|
||||
@@ -277,10 +258,10 @@ async fn drain_to_completed(
|
||||
sess.record_into_history(std::slice::from_ref(&item)).await;
|
||||
}
|
||||
Ok(ResponseEvent::RateLimits(snapshot)) => {
|
||||
sess.update_rate_limits(sub_id, snapshot).await;
|
||||
sess.update_rate_limits(turn_context, snapshot).await;
|
||||
}
|
||||
Ok(ResponseEvent::Completed { token_usage, .. }) => {
|
||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
||||
sess.update_token_usage_info(turn_context, token_usage.as_ref())
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::time::Instant;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::protocol::Event;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::McpInvocation;
|
||||
use crate::protocol::McpToolCallBeginEvent;
|
||||
@@ -15,7 +15,7 @@ use codex_protocol::models::ResponseInputItem;
|
||||
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
|
||||
pub(crate) async fn handle_mcp_tool_call(
|
||||
sess: &Session,
|
||||
sub_id: &str,
|
||||
turn_context: &TurnContext,
|
||||
call_id: String,
|
||||
server: String,
|
||||
tool_name: String,
|
||||
@@ -51,7 +51,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
call_id: call_id.clone(),
|
||||
invocation: invocation.clone(),
|
||||
});
|
||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await;
|
||||
notify_mcp_tool_call_event(sess, turn_context, tool_call_begin_event).await;
|
||||
|
||||
let start = Instant::now();
|
||||
// Perform the tool call.
|
||||
@@ -69,15 +69,11 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
result: result.clone(),
|
||||
});
|
||||
|
||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await;
|
||||
notify_mcp_tool_call_event(sess, turn_context, tool_call_end_event.clone()).await;
|
||||
|
||||
ResponseInputItem::McpToolCallOutput { call_id, result }
|
||||
}
|
||||
|
||||
async fn notify_mcp_tool_call_event(sess: &Session, sub_id: &str, event: EventMsg) {
|
||||
sess.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: event,
|
||||
})
|
||||
.await;
|
||||
async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) {
|
||||
sess.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ use tokio_util::task::AbortOnDropHandle;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
|
||||
@@ -53,10 +54,12 @@ pub(crate) struct RunningTask {
|
||||
pub(crate) task: Arc<dyn SessionTask>,
|
||||
pub(crate) cancellation_token: CancellationToken,
|
||||
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||
pub(crate) turn_context: Arc<TurnContext>,
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) {
|
||||
pub(crate) fn add_task(&mut self, task: RunningTask) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
@@ -65,8 +68,8 @@ impl ActiveTurn {
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn drain_tasks(&mut self) -> IndexMap<String, RunningTask> {
|
||||
std::mem::take(&mut self.tasks)
|
||||
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
|
||||
self.tasks.drain(..).map(|(_, task)| task).collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,9 @@ impl SessionTask for CompactTask {
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
_cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
||||
compact::run_compact_task(session.clone_session(), ctx, input).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ use tracing::warn;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnAbortReason;
|
||||
@@ -55,13 +54,12 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String>;
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
let _ = (session, sub_id);
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||
let _ = (session, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +67,6 @@ impl Session {
|
||||
pub async fn spawn_task<T: SessionTask>(
|
||||
self: &Arc<Self>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
task: T,
|
||||
) {
|
||||
@@ -86,14 +83,13 @@ impl Session {
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let sub_clone = sub_id.clone();
|
||||
let task_cancellation_token = cancellation_token.child_token();
|
||||
tokio::spawn(async move {
|
||||
let ctx_for_finish = Arc::clone(&ctx);
|
||||
let last_agent_message = task_for_run
|
||||
.run(
|
||||
Arc::clone(&session_ctx),
|
||||
ctx,
|
||||
sub_clone.clone(),
|
||||
input,
|
||||
task_cancellation_token.child_token(),
|
||||
)
|
||||
@@ -102,7 +98,8 @@ impl Session {
|
||||
if !task_cancellation_token.is_cancelled() {
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||
sess.on_task_finished(ctx_for_finish, last_agent_message)
|
||||
.await;
|
||||
}
|
||||
done_clone.notify_waiters();
|
||||
})
|
||||
@@ -114,60 +111,54 @@ impl Session {
|
||||
kind: task_kind,
|
||||
task,
|
||||
cancellation_token,
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
};
|
||||
self.register_new_active_task(sub_id, running_task).await;
|
||||
self.register_new_active_task(running_task).await;
|
||||
}
|
||||
|
||||
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||
for (sub_id, task) in self.take_all_running_tasks().await {
|
||||
self.handle_task_abort(sub_id, task, reason.clone()).await;
|
||||
for task in self.take_all_running_tasks().await {
|
||||
self.handle_task_abort(task, reason.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
turn_context: Arc<TurnContext>,
|
||||
last_agent_message: Option<String>,
|
||||
) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut()
|
||||
&& at.remove_task(&sub_id)
|
||||
&& at.remove_task(&turn_context.sub_id)
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
|
||||
self.send_event(turn_context.as_ref(), event).await;
|
||||
}
|
||||
|
||||
async fn register_new_active_task(&self, sub_id: String, task: RunningTask) {
|
||||
async fn register_new_active_task(&self, task: RunningTask) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let mut turn = ActiveTurn::default();
|
||||
turn.add_task(sub_id, task);
|
||||
turn.add_task(task);
|
||||
*active = Some(turn);
|
||||
}
|
||||
|
||||
async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> {
|
||||
async fn take_all_running_tasks(&self) -> Vec<RunningTask> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.take() {
|
||||
Some(mut at) => {
|
||||
at.clear_pending().await;
|
||||
let tasks = at.drain_tasks();
|
||||
tasks.into_iter().collect()
|
||||
|
||||
at.drain_tasks()
|
||||
}
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_abort(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
task: RunningTask,
|
||||
reason: TurnAbortReason,
|
||||
) {
|
||||
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
if task.cancellation_token.is_cancelled() {
|
||||
return;
|
||||
}
|
||||
@@ -187,13 +178,12 @@ impl Session {
|
||||
task.handle.abort();
|
||||
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task.abort(session_ctx, &sub_id).await;
|
||||
session_task
|
||||
.abort(session_ctx, Arc::clone(&task.turn_context))
|
||||
.await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
let event = EventMsg::TurnAborted(TurnAbortedEvent { reason });
|
||||
self.send_event(task.turn_context.as_ref(), event).await;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,19 +24,10 @@ impl SessionTask for RegularTask {
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(
|
||||
sess,
|
||||
ctx,
|
||||
sub_id,
|
||||
input,
|
||||
TaskKind::Regular,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
run_task(sess, ctx, input, TaskKind::Regular, cancellation_token).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,23 +25,14 @@ impl SessionTask for ReviewTask {
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(
|
||||
sess,
|
||||
ctx,
|
||||
sub_id,
|
||||
input,
|
||||
TaskKind::Review,
|
||||
cancellation_token,
|
||||
)
|
||||
.await
|
||||
run_task(sess, ctx, input, TaskKind::Review, cancellation_token).await
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||
exit_review_mode(session.clone_session(), ctx, None).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ pub struct ToolInvocation {
|
||||
pub session: Arc<Session>,
|
||||
pub turn: Arc<TurnContext>,
|
||||
pub tracker: SharedTurnDiffTracker,
|
||||
pub sub_id: String,
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
pub payload: ToolPayload,
|
||||
@@ -234,7 +233,7 @@ mod tests {
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct ExecCommandContext {
|
||||
pub(crate) sub_id: String,
|
||||
pub(crate) turn: Arc<TurnContext>,
|
||||
pub(crate) call_id: String,
|
||||
pub(crate) command_for_display: Vec<String>,
|
||||
pub(crate) cwd: PathBuf,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::parse_command::parse_command;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ExecCommandBeginEvent;
|
||||
use crate::protocol::ExecCommandEndEvent;
|
||||
@@ -20,7 +20,7 @@ use super::format_exec_output_str;
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) struct ToolEventCtx<'a> {
|
||||
pub session: &'a Session,
|
||||
pub sub_id: &'a str,
|
||||
pub turn: &'a TurnContext,
|
||||
pub call_id: &'a str,
|
||||
pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
||||
}
|
||||
@@ -28,13 +28,13 @@ pub(crate) struct ToolEventCtx<'a> {
|
||||
impl<'a> ToolEventCtx<'a> {
|
||||
pub fn new(
|
||||
session: &'a Session,
|
||||
sub_id: &'a str,
|
||||
turn: &'a TurnContext,
|
||||
call_id: &'a str,
|
||||
turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
||||
) -> Self {
|
||||
Self {
|
||||
session,
|
||||
sub_id,
|
||||
turn,
|
||||
call_id,
|
||||
turn_diff_tracker,
|
||||
}
|
||||
@@ -79,15 +79,15 @@ impl ToolEmitter {
|
||||
match (self, stage) {
|
||||
(Self::Shell { command, cwd }, ToolEventStage::Begin) => {
|
||||
ctx.session
|
||||
.send_event(Event {
|
||||
id: ctx.sub_id.to_string(),
|
||||
msg: EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
command: command.clone(),
|
||||
cwd: cwd.clone(),
|
||||
parsed_cmd: parse_command(command),
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
}
|
||||
(Self::Shell { .. }, ToolEventStage::Success(output)) => {
|
||||
@@ -139,14 +139,14 @@ impl ToolEmitter {
|
||||
guard.on_patch_begin(changes);
|
||||
}
|
||||
ctx.session
|
||||
.send_event(Event {
|
||||
id: ctx.sub_id.to_string(),
|
||||
msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
auto_approved: *auto_approved,
|
||||
changes: changes.clone(),
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
}
|
||||
(Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => {
|
||||
@@ -190,9 +190,9 @@ async fn emit_exec_end(
|
||||
formatted_output: String,
|
||||
) {
|
||||
ctx.session
|
||||
.send_event(Event {
|
||||
id: ctx.sub_id.to_string(),
|
||||
msg: EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
@@ -201,21 +201,21 @@ async fn emit_exec_end(
|
||||
duration,
|
||||
formatted_output,
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) {
|
||||
ctx.session
|
||||
.send_event(Event {
|
||||
id: ctx.sub_id.to_string(),
|
||||
msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(tracker) = ctx.turn_diff_tracker {
|
||||
@@ -225,10 +225,7 @@ async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, s
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
ctx.session
|
||||
.send_event(Event {
|
||||
id: ctx.sub_id.to_string(),
|
||||
msg: EventMsg::TurnDiff(TurnDiffEvent { unified_diff }),
|
||||
})
|
||||
.send_event(ctx.turn, EventMsg::TurnDiff(TurnDiffEvent { unified_diff }))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,6 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id,
|
||||
call_id,
|
||||
tool_name,
|
||||
payload,
|
||||
@@ -81,7 +80,6 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -19,7 +19,7 @@ impl ToolHandler for McpHandler {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
turn,
|
||||
call_id,
|
||||
payload,
|
||||
..
|
||||
@@ -43,7 +43,7 @@ impl ToolHandler for McpHandler {
|
||||
|
||||
let response = handle_mcp_tool_call(
|
||||
session.as_ref(),
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
call_id.clone(),
|
||||
server,
|
||||
tool,
|
||||
|
||||
@@ -21,8 +21,8 @@ use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::McpInvocation;
|
||||
use crate::protocol::McpToolCallBeginEvent;
|
||||
@@ -189,7 +189,7 @@ impl ToolHandler for McpResourceHandler {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
turn,
|
||||
call_id,
|
||||
tool_name,
|
||||
payload,
|
||||
@@ -211,7 +211,7 @@ impl ToolHandler for McpResourceHandler {
|
||||
"list_mcp_resources" => {
|
||||
handle_list_resources(
|
||||
Arc::clone(&session),
|
||||
sub_id.clone(),
|
||||
Arc::clone(&turn),
|
||||
call_id.clone(),
|
||||
arguments_value.clone(),
|
||||
)
|
||||
@@ -220,14 +220,20 @@ impl ToolHandler for McpResourceHandler {
|
||||
"list_mcp_resource_templates" => {
|
||||
handle_list_resource_templates(
|
||||
Arc::clone(&session),
|
||||
sub_id.clone(),
|
||||
Arc::clone(&turn),
|
||||
call_id.clone(),
|
||||
arguments_value.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
"read_mcp_resource" => {
|
||||
handle_read_resource(Arc::clone(&session), sub_id, call_id, arguments_value).await
|
||||
handle_read_resource(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
call_id,
|
||||
arguments_value,
|
||||
)
|
||||
.await
|
||||
}
|
||||
other => Err(FunctionCallError::RespondToModel(format!(
|
||||
"unsupported MCP resource tool: {other}"
|
||||
@@ -238,7 +244,7 @@ impl ToolHandler for McpResourceHandler {
|
||||
|
||||
async fn handle_list_resources(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
turn: Arc<TurnContext>,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
@@ -253,7 +259,7 @@ async fn handle_list_resources(
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
|
||||
@@ -297,7 +303,7 @@ async fn handle_list_resources(
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -311,7 +317,7 @@ async fn handle_list_resources(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -326,7 +332,7 @@ async fn handle_list_resources(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -340,7 +346,7 @@ async fn handle_list_resources(
|
||||
|
||||
async fn handle_list_resource_templates(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
turn: Arc<TurnContext>,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
@@ -355,7 +361,7 @@ async fn handle_list_resource_templates(
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
|
||||
@@ -403,7 +409,7 @@ async fn handle_list_resource_templates(
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -417,7 +423,7 @@ async fn handle_list_resource_templates(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -432,7 +438,7 @@ async fn handle_list_resource_templates(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -446,7 +452,7 @@ async fn handle_list_resource_templates(
|
||||
|
||||
async fn handle_read_resource(
|
||||
session: Arc<Session>,
|
||||
sub_id: String,
|
||||
turn: Arc<TurnContext>,
|
||||
call_id: String,
|
||||
arguments: Option<Value>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
@@ -461,7 +467,7 @@ async fn handle_read_resource(
|
||||
arguments: arguments.clone(),
|
||||
};
|
||||
|
||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
||||
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||
let start = Instant::now();
|
||||
|
||||
let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
|
||||
@@ -489,7 +495,7 @@ async fn handle_read_resource(
|
||||
let duration = start.elapsed();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -503,7 +509,7 @@ async fn handle_read_resource(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -518,7 +524,7 @@ async fn handle_read_resource(
|
||||
let message = err.to_string();
|
||||
emit_tool_call_end(
|
||||
&session,
|
||||
&sub_id,
|
||||
turn.as_ref(),
|
||||
&call_id,
|
||||
invocation,
|
||||
duration,
|
||||
@@ -544,39 +550,39 @@ fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallTo
|
||||
|
||||
async fn emit_tool_call_begin(
|
||||
session: &Arc<Session>,
|
||||
sub_id: &str,
|
||||
turn: &TurnContext,
|
||||
call_id: &str,
|
||||
invocation: McpInvocation,
|
||||
) {
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
.send_event(
|
||||
turn,
|
||||
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id: call_id.to_string(),
|
||||
invocation,
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn emit_tool_call_end(
|
||||
session: &Arc<Session>,
|
||||
sub_id: &str,
|
||||
turn: &TurnContext,
|
||||
call_id: &str,
|
||||
invocation: McpInvocation,
|
||||
duration: Duration,
|
||||
result: Result<CallToolResult, String>,
|
||||
) {
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||
.send_event(
|
||||
turn,
|
||||
EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||
call_id: call_id.to_string(),
|
||||
invocation,
|
||||
duration,
|
||||
result,
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::client_common::tools::ResponsesApiTool;
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
@@ -10,7 +11,6 @@ use crate::tools::registry::ToolKind;
|
||||
use crate::tools::spec::JsonSchema;
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::plan_tool::UpdatePlanArgs;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::LazyLock;
|
||||
@@ -68,7 +68,7 @@ impl ToolHandler for PlanHandler {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
turn,
|
||||
call_id,
|
||||
payload,
|
||||
..
|
||||
@@ -84,7 +84,7 @@ impl ToolHandler for PlanHandler {
|
||||
};
|
||||
|
||||
let content =
|
||||
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
|
||||
handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
@@ -98,16 +98,13 @@ impl ToolHandler for PlanHandler {
|
||||
/// than forcing it to come up and document a plan (TBD how that affects performance).
|
||||
pub(crate) async fn handle_update_plan(
|
||||
session: &Session,
|
||||
turn_context: &TurnContext,
|
||||
arguments: String,
|
||||
sub_id: String,
|
||||
_call_id: String,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
let args = parse_update_plan_arguments(&arguments)?;
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::PlanUpdate(args),
|
||||
})
|
||||
.send_event(turn_context, EventMsg::PlanUpdate(args))
|
||||
.await;
|
||||
Ok("Plan updated".to_string())
|
||||
}
|
||||
|
||||
@@ -47,7 +47,6 @@ impl ToolHandler for ShellHandler {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id,
|
||||
call_id,
|
||||
tool_name,
|
||||
payload,
|
||||
@@ -68,7 +67,6 @@ impl ToolHandler for ShellHandler {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -85,7 +83,6 @@ impl ToolHandler for ShellHandler {
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -37,7 +37,6 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
sub_id,
|
||||
call_id,
|
||||
tool_name: _tool_name,
|
||||
payload,
|
||||
@@ -91,7 +90,6 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
crate::unified_exec::UnifiedExecContext {
|
||||
session: &session,
|
||||
turn: turn.as_ref(),
|
||||
sub_id: &sub_id,
|
||||
call_id: &call_id,
|
||||
session_id: parsed_session_id,
|
||||
},
|
||||
|
||||
@@ -3,7 +3,6 @@ use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::ViewImageToolCallEvent;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
@@ -31,7 +30,6 @@ impl ToolHandler for ViewImageHandler {
|
||||
session,
|
||||
turn,
|
||||
payload,
|
||||
sub_id,
|
||||
call_id,
|
||||
..
|
||||
} = invocation;
|
||||
@@ -76,13 +74,13 @@ impl ToolHandler for ViewImageHandler {
|
||||
})?;
|
||||
|
||||
session
|
||||
.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
|
||||
.send_event(
|
||||
turn.as_ref(),
|
||||
EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
|
||||
call_id,
|
||||
path: event_path,
|
||||
}),
|
||||
})
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
|
||||
@@ -61,7 +61,6 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
let _otel_event_manager = turn_context.client.get_otel_event_manager();
|
||||
@@ -78,14 +77,8 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
// check if this was a patch, and apply it if so
|
||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||
MaybeApplyPatchVerified::Body(changes) => {
|
||||
match apply_patch::apply_patch(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&sub_id,
|
||||
&call_id,
|
||||
changes,
|
||||
)
|
||||
.await
|
||||
match apply_patch::apply_patch(sess.as_ref(), turn_context.as_ref(), &call_id, changes)
|
||||
.await
|
||||
{
|
||||
InternalApplyPatchInvocation::Output(item) => return item,
|
||||
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
||||
@@ -122,7 +115,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
),
|
||||
};
|
||||
|
||||
let event_ctx = ToolEventCtx::new(sess.as_ref(), &sub_id, &call_id, diff_opt);
|
||||
let event_ctx = ToolEventCtx::new(sess.as_ref(), turn_context.as_ref(), &call_id, diff_opt);
|
||||
event_emitter.emit(event_ctx, ToolEventStage::Begin).await;
|
||||
|
||||
// Build runtime contexts only when needed (shell/apply_patch below).
|
||||
@@ -141,7 +134,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
let mut runtime = ApplyPatchRuntime::new();
|
||||
let tool_ctx = ToolCtx {
|
||||
session: sess.as_ref(),
|
||||
sub_id: sub_id.clone(),
|
||||
turn: turn_context.as_ref(),
|
||||
call_id: call_id.clone(),
|
||||
tool_name: tool_name.to_string(),
|
||||
};
|
||||
@@ -172,7 +165,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
let mut runtime = ShellRuntime::new();
|
||||
let tool_ctx = ToolCtx {
|
||||
session: sess.as_ref(),
|
||||
sub_id: sub_id.clone(),
|
||||
turn: turn_context.as_ref(),
|
||||
call_id: call_id.clone(),
|
||||
tool_name: tool_name.to_string(),
|
||||
};
|
||||
|
||||
@@ -53,7 +53,7 @@ impl ToolOrchestrator {
|
||||
if needs_initial_approval {
|
||||
let approval_ctx = ApprovalCtx {
|
||||
session: tool_ctx.session,
|
||||
sub_id: &tool_ctx.sub_id,
|
||||
turn: turn_ctx,
|
||||
call_id: &tool_ctx.call_id,
|
||||
retry_reason: None,
|
||||
};
|
||||
@@ -110,7 +110,7 @@ impl ToolOrchestrator {
|
||||
let reason_msg = build_denial_reason_from_output(output.as_ref());
|
||||
let approval_ctx = ApprovalCtx {
|
||||
session: tool_ctx.session,
|
||||
sub_id: &tool_ctx.sub_id,
|
||||
turn: turn_ctx,
|
||||
call_id: &tool_ctx.call_id,
|
||||
retry_reason: Some(reason_msg),
|
||||
};
|
||||
|
||||
@@ -18,7 +18,6 @@ pub(crate) struct ToolCallRuntime {
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
parallel_execution: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
@@ -28,14 +27,12 @@ impl ToolCallRuntime {
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
router,
|
||||
session,
|
||||
turn_context,
|
||||
tracker,
|
||||
sub_id,
|
||||
parallel_execution: Arc::new(RwLock::new(())),
|
||||
}
|
||||
}
|
||||
@@ -50,7 +47,6 @@ impl ToolCallRuntime {
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let sub_id = self.sub_id.clone();
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
@@ -62,7 +58,7 @@ impl ToolCallRuntime {
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
||||
.dispatch_tool_call(session, turn, tracker, call)
|
||||
.await
|
||||
}));
|
||||
|
||||
|
||||
@@ -134,7 +134,6 @@ impl ToolRouter {
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call: ToolCall,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let ToolCall {
|
||||
@@ -149,7 +148,6 @@ impl ToolRouter {
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id,
|
||||
call_id,
|
||||
tool_name,
|
||||
payload,
|
||||
|
||||
@@ -68,7 +68,7 @@ impl ApplyPatchRuntime {
|
||||
|
||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||
Some(crate::exec::StdoutStream {
|
||||
sub_id: ctx.sub_id.clone(),
|
||||
sub_id: ctx.turn.sub_id.clone(),
|
||||
call_id: ctx.call_id.clone(),
|
||||
tx_event: ctx.session.get_tx_event(),
|
||||
})
|
||||
@@ -101,7 +101,7 @@ impl Approvable<ApplyPatchRequest> for ApplyPatchRuntime {
|
||||
) -> BoxFuture<'a, ReviewDecision> {
|
||||
let key = self.approval_key(req);
|
||||
let session = ctx.session;
|
||||
let sub_id = ctx.sub_id.to_string();
|
||||
let turn = ctx.turn;
|
||||
let call_id = ctx.call_id.to_string();
|
||||
let cwd = req.cwd.clone();
|
||||
let retry_reason = ctx.retry_reason.clone();
|
||||
@@ -111,7 +111,7 @@ impl Approvable<ApplyPatchRequest> for ApplyPatchRuntime {
|
||||
if let Some(reason) = retry_reason {
|
||||
session
|
||||
.request_command_approval(
|
||||
sub_id,
|
||||
turn,
|
||||
call_id,
|
||||
vec!["apply_patch".to_string()],
|
||||
cwd,
|
||||
|
||||
@@ -51,7 +51,7 @@ impl ShellRuntime {
|
||||
|
||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||
Some(crate::exec::StdoutStream {
|
||||
sub_id: ctx.sub_id.clone(),
|
||||
sub_id: ctx.turn.sub_id.clone(),
|
||||
call_id: ctx.call_id.clone(),
|
||||
tx_event: ctx.session.get_tx_event(),
|
||||
})
|
||||
@@ -91,12 +91,12 @@ impl Approvable<ShellRequest> for ShellRuntime {
|
||||
.clone()
|
||||
.or_else(|| req.justification.clone());
|
||||
let session = ctx.session;
|
||||
let sub_id = ctx.sub_id.to_string();
|
||||
let turn = ctx.turn;
|
||||
let call_id = ctx.call_id.to_string();
|
||||
Box::pin(async move {
|
||||
with_cached_approval(&session.services, key, || async move {
|
||||
session
|
||||
.request_command_approval(sub_id, call_id, command, cwd, reason)
|
||||
.request_command_approval(turn, call_id, command, cwd, reason)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -80,7 +80,7 @@ impl Approvable<UnifiedExecRequest> for UnifiedExecRuntime<'_> {
|
||||
) -> BoxFuture<'b, ReviewDecision> {
|
||||
let key = self.approval_key(req);
|
||||
let session = ctx.session;
|
||||
let sub_id = ctx.sub_id.to_string();
|
||||
let turn = ctx.turn;
|
||||
let call_id = ctx.call_id.to_string();
|
||||
let command = req.command.clone();
|
||||
let cwd = req.cwd.clone();
|
||||
@@ -88,7 +88,7 @@ impl Approvable<UnifiedExecRequest> for UnifiedExecRuntime<'_> {
|
||||
Box::pin(async move {
|
||||
with_cached_approval(&session.services, key, || async move {
|
||||
session
|
||||
.request_command_approval(sub_id, call_id, command, cwd, reason)
|
||||
.request_command_approval(turn, call_id, command, cwd, reason)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
//! and helpers (`Sandboxable`, `ToolRuntime`, `SandboxAttempt`, etc.).
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::sandboxing::CommandSpec;
|
||||
@@ -77,7 +78,7 @@ where
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ApprovalCtx<'a> {
|
||||
pub session: &'a Session,
|
||||
pub sub_id: &'a str,
|
||||
pub turn: &'a TurnContext,
|
||||
pub call_id: &'a str,
|
||||
pub retry_reason: Option<String>,
|
||||
}
|
||||
@@ -145,7 +146,7 @@ pub(crate) trait Sandboxable {
|
||||
|
||||
pub(crate) struct ToolCtx<'a> {
|
||||
pub session: &'a Session,
|
||||
pub sub_id: String,
|
||||
pub turn: &'a TurnContext,
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
}
|
||||
|
||||
@@ -43,7 +43,6 @@ const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
||||
pub(crate) struct UnifiedExecContext<'a> {
|
||||
pub session: &'a Session,
|
||||
pub turn: &'a TurnContext,
|
||||
pub sub_id: &'a str,
|
||||
pub call_id: &'a str,
|
||||
pub session_id: Option<i32>,
|
||||
}
|
||||
@@ -110,7 +109,6 @@ mod tests {
|
||||
UnifiedExecContext {
|
||||
session,
|
||||
turn: turn.as_ref(),
|
||||
sub_id: "sub",
|
||||
call_id: "call",
|
||||
session_id,
|
||||
},
|
||||
|
||||
@@ -113,7 +113,7 @@ impl UnifiedExecSessionManager {
|
||||
);
|
||||
let tool_ctx = ToolCtx {
|
||||
session: context.session,
|
||||
sub_id: context.sub_id.to_string(),
|
||||
turn: context.turn,
|
||||
call_id: context.call_id.to_string(),
|
||||
tool_name: "unified_exec".to_string(),
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user