Pass TurnContext around instead of sub_id (#5421)
Today `sub_id` is an ID of a single incoming Codex Op submition. We then associate all events triggered by this operation using the same `sub_id`. At the same time we are also creating a TurnContext per submission and we'd like to start associating some events (item added/item completed) with an entire turn instead of just the operation that started it. Using turn context when sending events give us flexibility to change notification scheme.
This commit is contained in:
@@ -36,7 +36,6 @@ pub(crate) struct ApplyPatchExec {
|
|||||||
pub(crate) async fn apply_patch(
|
pub(crate) async fn apply_patch(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
turn_context: &TurnContext,
|
turn_context: &TurnContext,
|
||||||
sub_id: &str,
|
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
action: ApplyPatchAction,
|
action: ApplyPatchAction,
|
||||||
) -> InternalApplyPatchInvocation {
|
) -> InternalApplyPatchInvocation {
|
||||||
@@ -62,7 +61,7 @@ pub(crate) async fn apply_patch(
|
|||||||
// that similar patches can be auto-approved in the future during
|
// that similar patches can be auto-approved in the future during
|
||||||
// this session.
|
// this session.
|
||||||
let rx_approve = sess
|
let rx_approve = sess
|
||||||
.request_patch_approval(sub_id.to_owned(), call_id.to_owned(), &action, None, None)
|
.request_patch_approval(turn_context, call_id.to_owned(), &action, None, None)
|
||||||
.await;
|
.await;
|
||||||
match rx_approve.await.unwrap_or_default() {
|
match rx_approve.await.unwrap_or_default() {
|
||||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
|
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
|
||||||
|
|||||||
@@ -255,6 +255,7 @@ pub(crate) struct Session {
|
|||||||
/// The context needed for a single turn of the conversation.
|
/// The context needed for a single turn of the conversation.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct TurnContext {
|
pub(crate) struct TurnContext {
|
||||||
|
pub(crate) sub_id: String,
|
||||||
pub(crate) client: ModelClient,
|
pub(crate) client: ModelClient,
|
||||||
/// The session's current working directory. All relative paths provided by
|
/// The session's current working directory. All relative paths provided by
|
||||||
/// the model as well as sandbox policies are resolved against this path
|
/// the model as well as sandbox policies are resolved against this path
|
||||||
@@ -359,6 +360,7 @@ impl Session {
|
|||||||
session_configuration: &SessionConfiguration,
|
session_configuration: &SessionConfiguration,
|
||||||
conversation_id: ConversationId,
|
conversation_id: ConversationId,
|
||||||
tx_event: Sender<Event>,
|
tx_event: Sender<Event>,
|
||||||
|
sub_id: String,
|
||||||
) -> TurnContext {
|
) -> TurnContext {
|
||||||
let config = session_configuration.original_config_do_not_use.clone();
|
let config = session_configuration.original_config_do_not_use.clone();
|
||||||
let model_family = find_family_for_model(&session_configuration.model)
|
let model_family = find_family_for_model(&session_configuration.model)
|
||||||
@@ -392,7 +394,10 @@ impl Session {
|
|||||||
features: &config.features,
|
features: &config.features,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let item_collector = ItemCollector::new(tx_event, conversation_id, sub_id.clone());
|
||||||
|
|
||||||
TurnContext {
|
TurnContext {
|
||||||
|
sub_id,
|
||||||
client,
|
client,
|
||||||
cwd: session_configuration.cwd.clone(),
|
cwd: session_configuration.cwd.clone(),
|
||||||
base_instructions: session_configuration.base_instructions.clone(),
|
base_instructions: session_configuration.base_instructions.clone(),
|
||||||
@@ -404,7 +409,7 @@ impl Session {
|
|||||||
is_review_mode: false,
|
is_review_mode: false,
|
||||||
final_output_json_schema: None,
|
final_output_json_schema: None,
|
||||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||||
item_collector: ItemCollector::new(tx_event, conversation_id, "turn_id".to_string()),
|
item_collector,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,7 +592,7 @@ impl Session {
|
|||||||
})
|
})
|
||||||
.chain(post_session_configured_error_events.into_iter());
|
.chain(post_session_configured_error_events.into_iter());
|
||||||
for event in events {
|
for event in events {
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(sess)
|
Ok(sess)
|
||||||
@@ -638,6 +643,15 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn new_turn(&self, updates: SessionSettingsUpdate) -> Arc<TurnContext> {
|
pub(crate) async fn new_turn(&self, updates: SessionSettingsUpdate) -> Arc<TurnContext> {
|
||||||
|
let sub_id = self.next_internal_sub_id();
|
||||||
|
self.new_turn_with_sub_id(sub_id, updates).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn new_turn_with_sub_id(
|
||||||
|
&self,
|
||||||
|
sub_id: String,
|
||||||
|
updates: SessionSettingsUpdate,
|
||||||
|
) -> Arc<TurnContext> {
|
||||||
let session_configuration = {
|
let session_configuration = {
|
||||||
let mut state = self.state.lock().await;
|
let mut state = self.state.lock().await;
|
||||||
let session_configuration = state.session_configuration.clone().apply(&updates);
|
let session_configuration = state.session_configuration.clone().apply(&updates);
|
||||||
@@ -652,6 +666,7 @@ impl Session {
|
|||||||
&session_configuration,
|
&session_configuration,
|
||||||
self.conversation_id,
|
self.conversation_id,
|
||||||
self.get_tx_event(),
|
self.get_tx_event(),
|
||||||
|
sub_id,
|
||||||
);
|
);
|
||||||
if let Some(final_schema) = updates.final_output_json_schema {
|
if let Some(final_schema) = updates.final_output_json_schema {
|
||||||
turn_context.final_output_json_schema = final_schema;
|
turn_context.final_output_json_schema = final_schema;
|
||||||
@@ -678,7 +693,15 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Persist the event to rollout and send it to clients.
|
/// Persist the event to rollout and send it to clients.
|
||||||
pub(crate) async fn send_event(&self, event: Event) {
|
pub(crate) async fn send_event(&self, turn_context: &TurnContext, msg: EventMsg) {
|
||||||
|
let event = Event {
|
||||||
|
id: turn_context.sub_id.clone(),
|
||||||
|
msg,
|
||||||
|
};
|
||||||
|
self.send_event_raw(event).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn send_event_raw(&self, event: Event) {
|
||||||
// Persist the event into rollout (recorder filters as needed)
|
// Persist the event into rollout (recorder filters as needed)
|
||||||
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())];
|
||||||
self.persist_rollout_items(&rollout_items).await;
|
self.persist_rollout_items(&rollout_items).await;
|
||||||
@@ -694,12 +717,13 @@ impl Session {
|
|||||||
/// default `ReviewDecision` (`Denied`).
|
/// default `ReviewDecision` (`Denied`).
|
||||||
pub async fn request_command_approval(
|
pub async fn request_command_approval(
|
||||||
&self,
|
&self,
|
||||||
sub_id: String,
|
turn_context: &TurnContext,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
command: Vec<String>,
|
command: Vec<String>,
|
||||||
cwd: PathBuf,
|
cwd: PathBuf,
|
||||||
reason: Option<String>,
|
reason: Option<String>,
|
||||||
) -> ReviewDecision {
|
) -> ReviewDecision {
|
||||||
|
let sub_id = turn_context.sub_id.clone();
|
||||||
// Add the tx_approve callback to the map before sending the request.
|
// Add the tx_approve callback to the map before sending the request.
|
||||||
let (tx_approve, rx_approve) = oneshot::channel();
|
let (tx_approve, rx_approve) = oneshot::channel();
|
||||||
let event_id = sub_id.clone();
|
let event_id = sub_id.clone();
|
||||||
@@ -718,28 +742,26 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let parsed_cmd = parse_command(&command);
|
let parsed_cmd = parse_command(&command);
|
||||||
let event = Event {
|
let event = EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
||||||
id: event_id,
|
call_id,
|
||||||
msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent {
|
command,
|
||||||
call_id,
|
cwd,
|
||||||
command,
|
reason,
|
||||||
cwd,
|
parsed_cmd,
|
||||||
reason,
|
});
|
||||||
parsed_cmd,
|
self.send_event(turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
rx_approve.await.unwrap_or_default()
|
rx_approve.await.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn request_patch_approval(
|
pub async fn request_patch_approval(
|
||||||
&self,
|
&self,
|
||||||
sub_id: String,
|
turn_context: &TurnContext,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
action: &ApplyPatchAction,
|
action: &ApplyPatchAction,
|
||||||
reason: Option<String>,
|
reason: Option<String>,
|
||||||
grant_root: Option<PathBuf>,
|
grant_root: Option<PathBuf>,
|
||||||
) -> oneshot::Receiver<ReviewDecision> {
|
) -> oneshot::Receiver<ReviewDecision> {
|
||||||
|
let sub_id = turn_context.sub_id.clone();
|
||||||
// Add the tx_approve callback to the map before sending the request.
|
// Add the tx_approve callback to the map before sending the request.
|
||||||
let (tx_approve, rx_approve) = oneshot::channel();
|
let (tx_approve, rx_approve) = oneshot::channel();
|
||||||
let event_id = sub_id.clone();
|
let event_id = sub_id.clone();
|
||||||
@@ -757,16 +779,13 @@ impl Session {
|
|||||||
warn!("Overwriting existing pending approval for sub_id: {event_id}");
|
warn!("Overwriting existing pending approval for sub_id: {event_id}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let event = Event {
|
let event = EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||||
id: event_id,
|
call_id,
|
||||||
msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
changes: convert_apply_patch_to_protocol(action),
|
||||||
call_id,
|
reason,
|
||||||
changes: convert_apply_patch_to_protocol(action),
|
grant_root,
|
||||||
reason,
|
});
|
||||||
grant_root,
|
self.send_event(turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
rx_approve
|
rx_approve
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -878,7 +897,6 @@ impl Session {
|
|||||||
|
|
||||||
async fn update_token_usage_info(
|
async fn update_token_usage_info(
|
||||||
&self,
|
&self,
|
||||||
sub_id: &str,
|
|
||||||
turn_context: &TurnContext,
|
turn_context: &TurnContext,
|
||||||
token_usage: Option<&TokenUsage>,
|
token_usage: Option<&TokenUsage>,
|
||||||
) {
|
) {
|
||||||
@@ -891,37 +909,38 @@ impl Session {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.send_token_count_event(sub_id).await;
|
self.send_token_count_event(turn_context).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) {
|
async fn update_rate_limits(
|
||||||
|
&self,
|
||||||
|
turn_context: &TurnContext,
|
||||||
|
new_rate_limits: RateLimitSnapshot,
|
||||||
|
) {
|
||||||
{
|
{
|
||||||
let mut state = self.state.lock().await;
|
let mut state = self.state.lock().await;
|
||||||
state.set_rate_limits(new_rate_limits);
|
state.set_rate_limits(new_rate_limits);
|
||||||
}
|
}
|
||||||
self.send_token_count_event(sub_id).await;
|
self.send_token_count_event(turn_context).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_token_count_event(&self, sub_id: &str) {
|
async fn send_token_count_event(&self, turn_context: &TurnContext) {
|
||||||
let (info, rate_limits) = {
|
let (info, rate_limits) = {
|
||||||
let state = self.state.lock().await;
|
let state = self.state.lock().await;
|
||||||
state.token_info_and_rate_limits()
|
state.token_info_and_rate_limits()
|
||||||
};
|
};
|
||||||
let event = Event {
|
let event = EventMsg::TokenCount(TokenCountEvent { info, rate_limits });
|
||||||
id: sub_id.to_string(),
|
self.send_event(turn_context, event).await;
|
||||||
msg: EventMsg::TokenCount(TokenCountEvent { info, rate_limits }),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) {
|
async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
|
||||||
let context_window = turn_context.client.get_model_context_window();
|
let context_window = turn_context.client.get_model_context_window();
|
||||||
if let Some(context_window) = context_window {
|
if let Some(context_window) = context_window {
|
||||||
{
|
{
|
||||||
let mut state = self.state.lock().await;
|
let mut state = self.state.lock().await;
|
||||||
state.set_token_usage_full(context_window);
|
state.set_token_usage_full(context_window);
|
||||||
}
|
}
|
||||||
self.send_token_count_event(sub_id).await;
|
self.send_token_count_event(turn_context).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,24 +970,22 @@ impl Session {
|
|||||||
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
/// Helper that emits a BackgroundEvent with the given message. This keeps
|
||||||
/// the call‑sites terse so adding more diagnostics does not clutter the
|
/// the call‑sites terse so adding more diagnostics does not clutter the
|
||||||
/// core agent logic.
|
/// core agent logic.
|
||||||
pub(crate) async fn notify_background_event(&self, sub_id: &str, message: impl Into<String>) {
|
pub(crate) async fn notify_background_event(
|
||||||
let event = Event {
|
&self,
|
||||||
id: sub_id.to_string(),
|
turn_context: &TurnContext,
|
||||||
msg: EventMsg::BackgroundEvent(BackgroundEventEvent {
|
message: impl Into<String>,
|
||||||
message: message.into(),
|
) {
|
||||||
}),
|
let event = EventMsg::BackgroundEvent(BackgroundEventEvent {
|
||||||
};
|
message: message.into(),
|
||||||
self.send_event(event).await;
|
});
|
||||||
|
self.send_event(turn_context, event).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn notify_stream_error(&self, sub_id: &str, message: impl Into<String>) {
|
async fn notify_stream_error(&self, turn_context: &TurnContext, message: impl Into<String>) {
|
||||||
let event = Event {
|
let event = EventMsg::StreamError(StreamErrorEvent {
|
||||||
id: sub_id.to_string(),
|
message: message.into(),
|
||||||
msg: EventMsg::StreamError(StreamErrorEvent {
|
});
|
||||||
message: message.into(),
|
self.send_event(turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the full turn input by concatenating the current conversation
|
/// Build the full turn input by concatenating the current conversation
|
||||||
@@ -1129,7 +1146,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
Op::UserInput { items } => (items, SessionSettingsUpdate::default()),
|
Op::UserInput { items } => (items, SessionSettingsUpdate::default()),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let current_context = sess.new_turn(updates).await;
|
let current_context = sess.new_turn_with_sub_id(sub.id.clone(), updates).await;
|
||||||
current_context
|
current_context
|
||||||
.client
|
.client
|
||||||
.get_otel_event_manager()
|
.get_otel_event_manager()
|
||||||
@@ -1145,11 +1162,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
&env_item,
|
&env_item,
|
||||||
sess.show_raw_agent_reasoning(),
|
sess.show_raw_agent_reasoning(),
|
||||||
) {
|
) {
|
||||||
let event = Event {
|
sess.send_event(¤t_context, msg).await;
|
||||||
id: sub.id.clone(),
|
|
||||||
msg,
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1158,7 +1171,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
.started_completed(TurnItem::UserMessage(UserMessageItem::new(&items)))
|
.started_completed(TurnItem::UserMessage(UserMessageItem::new(&items)))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
sess.spawn_task(Arc::clone(¤t_context), sub.id, items, RegularTask)
|
sess.spawn_task(Arc::clone(¤t_context), items, RegularTask)
|
||||||
.await;
|
.await;
|
||||||
previous_context = Some(current_context);
|
previous_context = Some(current_context);
|
||||||
}
|
}
|
||||||
@@ -1216,7 +1229,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
sess_clone.send_event(event).await;
|
sess_clone.send_event_raw(event).await;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Op::ListMcpTools => {
|
Op::ListMcpTools => {
|
||||||
@@ -1249,7 +1262,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
}
|
}
|
||||||
Op::ListCustomPrompts => {
|
Op::ListCustomPrompts => {
|
||||||
let sub_id = sub.id.clone();
|
let sub_id = sub.id.clone();
|
||||||
@@ -1267,10 +1280,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
custom_prompts,
|
custom_prompts,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
}
|
}
|
||||||
Op::Compact => {
|
Op::Compact => {
|
||||||
let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await;
|
let turn_context = sess
|
||||||
|
.new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default())
|
||||||
|
.await;
|
||||||
// Attempt to inject input into current task
|
// Attempt to inject input into current task
|
||||||
if let Err(items) = sess
|
if let Err(items) = sess
|
||||||
.inject_input(vec![UserInput::Text {
|
.inject_input(vec![UserInput::Text {
|
||||||
@@ -1278,7 +1293,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
}])
|
}])
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
|
sess.spawn_task(Arc::clone(&turn_context), items, CompactTask)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1302,14 +1317,14 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
message: "Failed to shutdown rollout recorder".to_string(),
|
message: "Failed to shutdown rollout recorder".to_string(),
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let event = Event {
|
let event = Event {
|
||||||
id: sub.id.clone(),
|
id: sub.id.clone(),
|
||||||
msg: EventMsg::ShutdownComplete,
|
msg: EventMsg::ShutdownComplete,
|
||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Op::GetPath => {
|
Op::GetPath => {
|
||||||
@@ -1337,10 +1352,12 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
path,
|
path,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event_raw(event).await;
|
||||||
}
|
}
|
||||||
Op::Review { review_request } => {
|
Op::Review { review_request } => {
|
||||||
let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await;
|
let turn_context = sess
|
||||||
|
.new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default())
|
||||||
|
.await;
|
||||||
spawn_review_thread(
|
spawn_review_thread(
|
||||||
sess.clone(),
|
sess.clone(),
|
||||||
config.clone(),
|
config.clone(),
|
||||||
@@ -1416,6 +1433,7 @@ async fn spawn_review_thread(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let review_turn_context = TurnContext {
|
let review_turn_context = TurnContext {
|
||||||
|
sub_id: sub_id.to_string(),
|
||||||
client,
|
client,
|
||||||
tools_config,
|
tools_config,
|
||||||
user_instructions: None,
|
user_instructions: None,
|
||||||
@@ -1439,17 +1457,11 @@ async fn spawn_review_thread(
|
|||||||
text: review_prompt,
|
text: review_prompt,
|
||||||
}];
|
}];
|
||||||
let tc = Arc::new(review_turn_context);
|
let tc = Arc::new(review_turn_context);
|
||||||
|
sess.spawn_task(tc.clone(), input, ReviewTask).await;
|
||||||
// Clone sub_id for the upcoming announcement before moving it into the task.
|
|
||||||
let sub_id_for_event = sub_id.clone();
|
|
||||||
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
|
|
||||||
|
|
||||||
// Announce entering review mode so UIs can switch modes.
|
// Announce entering review mode so UIs can switch modes.
|
||||||
sess.send_event(Event {
|
sess.send_event(&tc, EventMsg::EnteredReviewMode(review_request))
|
||||||
id: sub_id_for_event,
|
.await;
|
||||||
msg: EventMsg::EnteredReviewMode(review_request),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Takes a user message as input and runs a loop where, at each turn, the model
|
/// Takes a user message as input and runs a loop where, at each turn, the model
|
||||||
@@ -1472,7 +1484,6 @@ async fn spawn_review_thread(
|
|||||||
pub(crate) async fn run_task(
|
pub(crate) async fn run_task(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
@@ -1480,13 +1491,10 @@ pub(crate) async fn run_task(
|
|||||||
if input.is_empty() {
|
if input.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let event = Event {
|
let event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||||
id: sub_id.clone(),
|
model_context_window: turn_context.client.get_model_context_window(),
|
||||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
});
|
||||||
model_context_window: turn_context.client.get_model_context_window(),
|
sess.send_event(&turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
|
|
||||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||||
// For review threads, keep an isolated in-memory history so the
|
// For review threads, keep an isolated in-memory history so the
|
||||||
@@ -1557,7 +1565,6 @@ pub(crate) async fn run_task(
|
|||||||
Arc::clone(&sess),
|
Arc::clone(&sess),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
Arc::clone(&turn_diff_tracker),
|
Arc::clone(&turn_diff_tracker),
|
||||||
sub_id.clone(),
|
|
||||||
turn_input,
|
turn_input,
|
||||||
task_kind,
|
task_kind,
|
||||||
cancellation_token.child_token(),
|
cancellation_token.child_token(),
|
||||||
@@ -1689,15 +1696,12 @@ pub(crate) async fn run_task(
|
|||||||
let current_tokens = total_usage_tokens
|
let current_tokens = total_usage_tokens
|
||||||
.map(|tokens| tokens.to_string())
|
.map(|tokens| tokens.to_string())
|
||||||
.unwrap_or_else(|| "unknown".to_string());
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
let event = Event {
|
let event = EventMsg::Error(ErrorEvent {
|
||||||
id: sub_id.clone(),
|
message: format!(
|
||||||
msg: EventMsg::Error(ErrorEvent {
|
"Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input."
|
||||||
message: format!(
|
),
|
||||||
"Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input."
|
});
|
||||||
),
|
sess.send_event(&turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
auto_compact_recently_attempted = true;
|
auto_compact_recently_attempted = true;
|
||||||
@@ -1714,7 +1718,7 @@ pub(crate) async fn run_task(
|
|||||||
sess.notifier()
|
sess.notifier()
|
||||||
.notify(&UserNotification::AgentTurnComplete {
|
.notify(&UserNotification::AgentTurnComplete {
|
||||||
thread_id: sess.conversation_id.to_string(),
|
thread_id: sess.conversation_id.to_string(),
|
||||||
turn_id: sub_id.clone(),
|
turn_id: turn_context.sub_id.clone(),
|
||||||
cwd: turn_context.cwd.display().to_string(),
|
cwd: turn_context.cwd.display().to_string(),
|
||||||
input_messages: turn_input_messages,
|
input_messages: turn_input_messages,
|
||||||
last_assistant_message: last_agent_message.clone(),
|
last_assistant_message: last_agent_message.clone(),
|
||||||
@@ -1729,13 +1733,10 @@ pub(crate) async fn run_task(
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Turn error: {e:#}");
|
info!("Turn error: {e:#}");
|
||||||
let event = Event {
|
let event = EventMsg::Error(ErrorEvent {
|
||||||
id: sub_id.clone(),
|
message: e.to_string(),
|
||||||
msg: EventMsg::Error(ErrorEvent {
|
});
|
||||||
message: e.to_string(),
|
sess.send_event(&turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
// let the user continue the conversation
|
// let the user continue the conversation
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1752,7 +1753,7 @@ pub(crate) async fn run_task(
|
|||||||
if turn_context.is_review_mode {
|
if turn_context.is_review_mode {
|
||||||
exit_review_mode(
|
exit_review_mode(
|
||||||
sess.clone(),
|
sess.clone(),
|
||||||
sub_id.clone(),
|
Arc::clone(&turn_context),
|
||||||
last_agent_message.as_deref().map(parse_review_output_event),
|
last_agent_message.as_deref().map(parse_review_output_event),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@@ -1790,7 +1791,6 @@ async fn run_turn(
|
|||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
turn_diff_tracker: SharedTurnDiffTracker,
|
turn_diff_tracker: SharedTurnDiffTracker,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<ResponseItem>,
|
input: Vec<ResponseItem>,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
@@ -1821,7 +1821,6 @@ async fn run_turn(
|
|||||||
Arc::clone(&sess),
|
Arc::clone(&sess),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
Arc::clone(&turn_diff_tracker),
|
Arc::clone(&turn_diff_tracker),
|
||||||
&sub_id,
|
|
||||||
&prompt,
|
&prompt,
|
||||||
task_kind,
|
task_kind,
|
||||||
cancellation_token.child_token(),
|
cancellation_token.child_token(),
|
||||||
@@ -1834,13 +1833,14 @@ async fn run_turn(
|
|||||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||||
sess.set_total_tokens_full(&sub_id, &turn_context).await;
|
sess.set_total_tokens_full(turn_context.as_ref()).await;
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
Err(CodexErr::UsageLimitReached(e)) => {
|
Err(CodexErr::UsageLimitReached(e)) => {
|
||||||
let rate_limits = e.rate_limits.clone();
|
let rate_limits = e.rate_limits.clone();
|
||||||
if let Some(rate_limits) = rate_limits {
|
if let Some(rate_limits) = rate_limits {
|
||||||
sess.update_rate_limits(&sub_id, rate_limits).await;
|
sess.update_rate_limits(turn_context.as_ref(), rate_limits)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
return Err(CodexErr::UsageLimitReached(e));
|
return Err(CodexErr::UsageLimitReached(e));
|
||||||
}
|
}
|
||||||
@@ -1862,7 +1862,7 @@ async fn run_turn(
|
|||||||
// user understands what is happening instead of staring
|
// user understands what is happening instead of staring
|
||||||
// at a seemingly frozen screen.
|
// at a seemingly frozen screen.
|
||||||
sess.notify_stream_error(
|
sess.notify_stream_error(
|
||||||
&sub_id,
|
turn_context.as_ref(),
|
||||||
format!("Re-connecting... {retries}/{max_retries}"),
|
format!("Re-connecting... {retries}/{max_retries}"),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@@ -1898,7 +1898,6 @@ async fn try_run_turn(
|
|||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
turn_diff_tracker: SharedTurnDiffTracker,
|
turn_diff_tracker: SharedTurnDiffTracker,
|
||||||
sub_id: &str,
|
|
||||||
prompt: &Prompt,
|
prompt: &Prompt,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
@@ -1979,7 +1978,6 @@ async fn try_run_turn(
|
|||||||
Arc::clone(&sess),
|
Arc::clone(&sess),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
Arc::clone(&turn_diff_tracker),
|
Arc::clone(&turn_diff_tracker),
|
||||||
sub_id.to_string(),
|
|
||||||
);
|
);
|
||||||
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
|
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
|
||||||
FuturesOrdered::new();
|
FuturesOrdered::new();
|
||||||
@@ -2028,7 +2026,6 @@ async fn try_run_turn(
|
|||||||
let response = handle_non_tool_response_item(
|
let response = handle_non_tool_response_item(
|
||||||
Arc::clone(&sess),
|
Arc::clone(&sess),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
sub_id,
|
|
||||||
item.clone(),
|
item.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -2077,7 +2074,7 @@ async fn try_run_turn(
|
|||||||
let _ = sess
|
let _ = sess
|
||||||
.tx_event
|
.tx_event
|
||||||
.send(Event {
|
.send(Event {
|
||||||
id: sub_id.to_string(),
|
id: turn_context.sub_id.clone(),
|
||||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
@@ -2085,13 +2082,14 @@ async fn try_run_turn(
|
|||||||
ResponseEvent::RateLimits(snapshot) => {
|
ResponseEvent::RateLimits(snapshot) => {
|
||||||
// Update internal state with latest rate limits, but defer sending until
|
// Update internal state with latest rate limits, but defer sending until
|
||||||
// token usage is available to avoid duplicate TokenCount events.
|
// token usage is available to avoid duplicate TokenCount events.
|
||||||
sess.update_rate_limits(sub_id, snapshot).await;
|
sess.update_rate_limits(turn_context.as_ref(), snapshot)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
ResponseEvent::Completed {
|
ResponseEvent::Completed {
|
||||||
response_id: _,
|
response_id: _,
|
||||||
token_usage,
|
token_usage,
|
||||||
} => {
|
} => {
|
||||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let processed_items = output
|
let processed_items = output
|
||||||
@@ -2105,11 +2103,7 @@ async fn try_run_turn(
|
|||||||
};
|
};
|
||||||
if let Ok(Some(unified_diff)) = unified_diff {
|
if let Ok(Some(unified_diff)) = unified_diff {
|
||||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||||
let event = Event {
|
sess.send_event(&turn_context, msg).await;
|
||||||
id: sub_id.to_string(),
|
|
||||||
msg,
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = TurnRunResult {
|
let result = TurnRunResult {
|
||||||
@@ -2123,38 +2117,27 @@ async fn try_run_turn(
|
|||||||
// In review child threads, suppress assistant text deltas; the
|
// In review child threads, suppress assistant text deltas; the
|
||||||
// UI will show a selection popup from the final ReviewOutput.
|
// UI will show a selection popup from the final ReviewOutput.
|
||||||
if !turn_context.is_review_mode {
|
if !turn_context.is_review_mode {
|
||||||
let event = Event {
|
let event = EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta });
|
||||||
id: sub_id.to_string(),
|
sess.send_event(&turn_context, event).await;
|
||||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
} else {
|
} else {
|
||||||
trace!("suppressing OutputTextDelta in review mode");
|
trace!("suppressing OutputTextDelta in review mode");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ResponseEvent::ReasoningSummaryDelta(delta) => {
|
ResponseEvent::ReasoningSummaryDelta(delta) => {
|
||||||
let event = Event {
|
let event = EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta });
|
||||||
id: sub_id.to_string(),
|
sess.send_event(&turn_context, event).await;
|
||||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
ResponseEvent::ReasoningSummaryPartAdded => {
|
ResponseEvent::ReasoningSummaryPartAdded => {
|
||||||
let event = Event {
|
let event =
|
||||||
id: sub_id.to_string(),
|
EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {});
|
||||||
msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}),
|
sess.send_event(&turn_context, event).await;
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
ResponseEvent::ReasoningContentDelta(delta) => {
|
ResponseEvent::ReasoningContentDelta(delta) => {
|
||||||
if sess.show_raw_agent_reasoning() {
|
if sess.show_raw_agent_reasoning() {
|
||||||
let event = Event {
|
let event = EventMsg::AgentReasoningRawContentDelta(
|
||||||
id: sub_id.to_string(),
|
AgentReasoningRawContentDeltaEvent { delta },
|
||||||
msg: EventMsg::AgentReasoningRawContentDelta(
|
);
|
||||||
AgentReasoningRawContentDeltaEvent { delta },
|
sess.send_event(&turn_context, event).await;
|
||||||
),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2164,7 +2147,6 @@ async fn try_run_turn(
|
|||||||
async fn handle_non_tool_response_item(
|
async fn handle_non_tool_response_item(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: &str,
|
|
||||||
item: ResponseItem,
|
item: ResponseItem,
|
||||||
) -> CodexResult<Option<ResponseInputItem>> {
|
) -> CodexResult<Option<ResponseInputItem>> {
|
||||||
debug!(?item, "Output item");
|
debug!(?item, "Output item");
|
||||||
@@ -2181,11 +2163,7 @@ async fn handle_non_tool_response_item(
|
|||||||
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
|
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
|
||||||
};
|
};
|
||||||
for msg in msgs {
|
for msg in msgs {
|
||||||
let event = Event {
|
sess.send_event(&turn_context, msg).await;
|
||||||
id: sub_id.to_string(),
|
|
||||||
msg,
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
||||||
@@ -2255,16 +2233,13 @@ fn convert_call_tool_result_to_function_call_output_payload(
|
|||||||
/// and records a developer message with the review output.
|
/// and records a developer message with the review output.
|
||||||
pub(crate) async fn exit_review_mode(
|
pub(crate) async fn exit_review_mode(
|
||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
task_sub_id: String,
|
turn_context: Arc<TurnContext>,
|
||||||
review_output: Option<ReviewOutputEvent>,
|
review_output: Option<ReviewOutputEvent>,
|
||||||
) {
|
) {
|
||||||
let event = Event {
|
let event = EventMsg::ExitedReviewMode(ExitedReviewModeEvent {
|
||||||
id: task_sub_id,
|
review_output: review_output.clone(),
|
||||||
msg: EventMsg::ExitedReviewMode(ExitedReviewModeEvent {
|
});
|
||||||
review_output: review_output.clone(),
|
session.send_event(turn_context.as_ref(), event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
session.send_event(event).await;
|
|
||||||
|
|
||||||
let mut user_message = String::new();
|
let mut user_message = String::new();
|
||||||
if let Some(out) = review_output {
|
if let Some(out) = review_output {
|
||||||
@@ -2676,6 +2651,7 @@ mod tests {
|
|||||||
&session_configuration,
|
&session_configuration,
|
||||||
conversation_id,
|
conversation_id,
|
||||||
tx_event.clone(),
|
tx_event.clone(),
|
||||||
|
"turn_id".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let session = Session {
|
let session = Session {
|
||||||
@@ -2744,6 +2720,7 @@ mod tests {
|
|||||||
&session_configuration,
|
&session_configuration,
|
||||||
conversation_id,
|
conversation_id,
|
||||||
tx_event.clone(),
|
tx_event.clone(),
|
||||||
|
"turn_id".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
let session = Arc::new(Session {
|
let session = Arc::new(Session {
|
||||||
@@ -2774,7 +2751,6 @@ mod tests {
|
|||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
_session: Arc<SessionTaskContext>,
|
_session: Arc<SessionTaskContext>,
|
||||||
_ctx: Arc<TurnContext>,
|
_ctx: Arc<TurnContext>,
|
||||||
_sub_id: String,
|
|
||||||
_input: Vec<UserInput>,
|
_input: Vec<UserInput>,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
@@ -2787,9 +2763,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||||
if let TaskKind::Review = self.kind {
|
if let TaskKind::Review = self.kind {
|
||||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
exit_review_mode(session.clone_session(), ctx, None).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2798,13 +2774,11 @@ mod tests {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
let sub_id = "sub-regular".to_string();
|
|
||||||
let input = vec![UserInput::Text {
|
let input = vec![UserInput::Text {
|
||||||
text: "hello".to_string(),
|
text: "hello".to_string(),
|
||||||
}];
|
}];
|
||||||
sess.spawn_task(
|
sess.spawn_task(
|
||||||
Arc::clone(&tc),
|
Arc::clone(&tc),
|
||||||
sub_id.clone(),
|
|
||||||
input,
|
input,
|
||||||
NeverEndingTask {
|
NeverEndingTask {
|
||||||
kind: TaskKind::Regular,
|
kind: TaskKind::Regular,
|
||||||
@@ -2829,13 +2803,11 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn abort_gracefuly_emits_turn_aborted_only() {
|
async fn abort_gracefuly_emits_turn_aborted_only() {
|
||||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
let sub_id = "sub-regular".to_string();
|
|
||||||
let input = vec![UserInput::Text {
|
let input = vec![UserInput::Text {
|
||||||
text: "hello".to_string(),
|
text: "hello".to_string(),
|
||||||
}];
|
}];
|
||||||
sess.spawn_task(
|
sess.spawn_task(
|
||||||
Arc::clone(&tc),
|
Arc::clone(&tc),
|
||||||
sub_id.clone(),
|
|
||||||
input,
|
input,
|
||||||
NeverEndingTask {
|
NeverEndingTask {
|
||||||
kind: TaskKind::Regular,
|
kind: TaskKind::Regular,
|
||||||
@@ -2857,13 +2829,11 @@ mod tests {
|
|||||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
let sub_id = "sub-review".to_string();
|
|
||||||
let input = vec![UserInput::Text {
|
let input = vec![UserInput::Text {
|
||||||
text: "start review".to_string(),
|
text: "start review".to_string(),
|
||||||
}];
|
}];
|
||||||
sess.spawn_task(
|
sess.spawn_task(
|
||||||
Arc::clone(&tc),
|
Arc::clone(&tc),
|
||||||
sub_id.clone(),
|
|
||||||
input,
|
input,
|
||||||
NeverEndingTask {
|
NeverEndingTask {
|
||||||
kind: TaskKind::Review,
|
kind: TaskKind::Review,
|
||||||
@@ -2935,7 +2905,6 @@ mod tests {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
tracker,
|
tracker,
|
||||||
"sub-id".to_string(),
|
|
||||||
call,
|
call,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -3095,7 +3064,6 @@ mod tests {
|
|||||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||||
|
|
||||||
let tool_name = "shell";
|
let tool_name = "shell";
|
||||||
let sub_id = "test-sub".to_string();
|
|
||||||
let call_id = "test-call".to_string();
|
let call_id = "test-call".to_string();
|
||||||
|
|
||||||
let resp = handle_container_exec_with_params(
|
let resp = handle_container_exec_with_params(
|
||||||
@@ -3104,7 +3072,6 @@ mod tests {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
Arc::clone(&turn_diff_tracker),
|
Arc::clone(&turn_diff_tracker),
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@@ -3132,7 +3099,6 @@ mod tests {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn_context),
|
Arc::clone(&turn_context),
|
||||||
Arc::clone(&turn_diff_tracker),
|
Arc::clone(&turn_diff_tracker),
|
||||||
"test-sub".to_string(),
|
|
||||||
"test-call-2".to_string(),
|
"test-call-2".to_string(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ use crate::error::Result as CodexResult;
|
|||||||
use crate::protocol::AgentMessageEvent;
|
use crate::protocol::AgentMessageEvent;
|
||||||
use crate::protocol::CompactedItem;
|
use crate::protocol::CompactedItem;
|
||||||
use crate::protocol::ErrorEvent;
|
use crate::protocol::ErrorEvent;
|
||||||
use crate::protocol::Event;
|
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::InputMessageKind;
|
use crate::protocol::InputMessageKind;
|
||||||
use crate::protocol::TaskStartedEvent;
|
use crate::protocol::TaskStartedEvent;
|
||||||
@@ -40,34 +39,28 @@ pub(crate) async fn run_inline_auto_compact_task(
|
|||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
) {
|
) {
|
||||||
let sub_id = sess.next_internal_sub_id();
|
|
||||||
let input = vec![UserInput::Text {
|
let input = vec![UserInput::Text {
|
||||||
text: SUMMARIZATION_PROMPT.to_string(),
|
text: SUMMARIZATION_PROMPT.to_string(),
|
||||||
}];
|
}];
|
||||||
run_compact_task_inner(sess, turn_context, sub_id, input).await;
|
run_compact_task_inner(sess, turn_context, input).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn run_compact_task(
|
pub(crate) async fn run_compact_task(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
let start_event = Event {
|
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
|
||||||
id: sub_id.clone(),
|
model_context_window: turn_context.client.get_model_context_window(),
|
||||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
});
|
||||||
model_context_window: turn_context.client.get_model_context_window(),
|
sess.send_event(&turn_context, start_event).await;
|
||||||
}),
|
run_compact_task_inner(sess.clone(), turn_context, input).await;
|
||||||
};
|
|
||||||
sess.send_event(start_event).await;
|
|
||||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await;
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_compact_task_inner(
|
async fn run_compact_task_inner(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
) {
|
) {
|
||||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||||
@@ -94,14 +87,13 @@ async fn run_compact_task_inner(
|
|||||||
input: turn_input.clone(),
|
input: turn_input.clone(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let attempt_result =
|
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
||||||
drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await;
|
|
||||||
|
|
||||||
match attempt_result {
|
match attempt_result {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
if truncated_count > 0 {
|
if truncated_count > 0 {
|
||||||
sess.notify_background_event(
|
sess.notify_background_event(
|
||||||
&sub_id,
|
turn_context.as_ref(),
|
||||||
format!(
|
format!(
|
||||||
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
|
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
|
||||||
),
|
),
|
||||||
@@ -120,15 +112,11 @@ async fn run_compact_task_inner(
|
|||||||
retries = 0;
|
retries = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
|
sess.set_total_tokens_full(turn_context.as_ref()).await;
|
||||||
.await;
|
let event = EventMsg::Error(ErrorEvent {
|
||||||
let event = Event {
|
message: e.to_string(),
|
||||||
id: sub_id.clone(),
|
});
|
||||||
msg: EventMsg::Error(ErrorEvent {
|
sess.send_event(&turn_context, event).await;
|
||||||
message: e.to_string(),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -136,20 +124,17 @@ async fn run_compact_task_inner(
|
|||||||
retries += 1;
|
retries += 1;
|
||||||
let delay = backoff(retries);
|
let delay = backoff(retries);
|
||||||
sess.notify_stream_error(
|
sess.notify_stream_error(
|
||||||
&sub_id,
|
turn_context.as_ref(),
|
||||||
format!("Re-connecting... {retries}/{max_retries}"),
|
format!("Re-connecting... {retries}/{max_retries}"),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
let event = Event {
|
let event = EventMsg::Error(ErrorEvent {
|
||||||
id: sub_id.clone(),
|
message: e.to_string(),
|
||||||
msg: EventMsg::Error(ErrorEvent {
|
});
|
||||||
message: e.to_string(),
|
sess.send_event(&turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -168,13 +153,10 @@ async fn run_compact_task_inner(
|
|||||||
});
|
});
|
||||||
sess.persist_rollout_items(&[rollout_item]).await;
|
sess.persist_rollout_items(&[rollout_item]).await;
|
||||||
|
|
||||||
let event = Event {
|
let event = EventMsg::AgentMessage(AgentMessageEvent {
|
||||||
id: sub_id.clone(),
|
message: "Compact task completed".to_string(),
|
||||||
msg: EventMsg::AgentMessage(AgentMessageEvent {
|
});
|
||||||
message: "Compact task completed".to_string(),
|
sess.send_event(&turn_context, event).await;
|
||||||
}),
|
|
||||||
};
|
|
||||||
sess.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
|
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
|
||||||
@@ -256,7 +238,6 @@ pub(crate) fn build_compacted_history(
|
|||||||
async fn drain_to_completed(
|
async fn drain_to_completed(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
turn_context: &TurnContext,
|
turn_context: &TurnContext,
|
||||||
sub_id: &str,
|
|
||||||
prompt: &Prompt,
|
prompt: &Prompt,
|
||||||
) -> CodexResult<()> {
|
) -> CodexResult<()> {
|
||||||
let mut stream = turn_context
|
let mut stream = turn_context
|
||||||
@@ -277,10 +258,10 @@ async fn drain_to_completed(
|
|||||||
sess.record_into_history(std::slice::from_ref(&item)).await;
|
sess.record_into_history(std::slice::from_ref(&item)).await;
|
||||||
}
|
}
|
||||||
Ok(ResponseEvent::RateLimits(snapshot)) => {
|
Ok(ResponseEvent::RateLimits(snapshot)) => {
|
||||||
sess.update_rate_limits(sub_id, snapshot).await;
|
sess.update_rate_limits(turn_context, snapshot).await;
|
||||||
}
|
}
|
||||||
Ok(ResponseEvent::Completed { token_usage, .. }) => {
|
Ok(ResponseEvent::Completed { token_usage, .. }) => {
|
||||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
sess.update_token_usage_info(turn_context, token_usage.as_ref())
|
||||||
.await;
|
.await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::time::Instant;
|
|||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
use crate::protocol::Event;
|
use crate::codex::TurnContext;
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::McpInvocation;
|
use crate::protocol::McpInvocation;
|
||||||
use crate::protocol::McpToolCallBeginEvent;
|
use crate::protocol::McpToolCallBeginEvent;
|
||||||
@@ -15,7 +15,7 @@ use codex_protocol::models::ResponseInputItem;
|
|||||||
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
|
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
|
||||||
pub(crate) async fn handle_mcp_tool_call(
|
pub(crate) async fn handle_mcp_tool_call(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
sub_id: &str,
|
turn_context: &TurnContext,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
server: String,
|
server: String,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
@@ -51,7 +51,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
|||||||
call_id: call_id.clone(),
|
call_id: call_id.clone(),
|
||||||
invocation: invocation.clone(),
|
invocation: invocation.clone(),
|
||||||
});
|
});
|
||||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await;
|
notify_mcp_tool_call_event(sess, turn_context, tool_call_begin_event).await;
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
// Perform the tool call.
|
// Perform the tool call.
|
||||||
@@ -69,15 +69,11 @@ pub(crate) async fn handle_mcp_tool_call(
|
|||||||
result: result.clone(),
|
result: result.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await;
|
notify_mcp_tool_call_event(sess, turn_context, tool_call_end_event.clone()).await;
|
||||||
|
|
||||||
ResponseInputItem::McpToolCallOutput { call_id, result }
|
ResponseInputItem::McpToolCallOutput { call_id, result }
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn notify_mcp_tool_call_event(sess: &Session, sub_id: &str, event: EventMsg) {
|
async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) {
|
||||||
sess.send_event(Event {
|
sess.send_event(turn_context, event).await;
|
||||||
id: sub_id.to_string(),
|
|
||||||
msg: event,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ use tokio_util::task::AbortOnDropHandle;
|
|||||||
use codex_protocol::models::ResponseInputItem;
|
use codex_protocol::models::ResponseInputItem;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::protocol::ReviewDecision;
|
use crate::protocol::ReviewDecision;
|
||||||
use crate::tasks::SessionTask;
|
use crate::tasks::SessionTask;
|
||||||
|
|
||||||
@@ -53,10 +54,12 @@ pub(crate) struct RunningTask {
|
|||||||
pub(crate) task: Arc<dyn SessionTask>,
|
pub(crate) task: Arc<dyn SessionTask>,
|
||||||
pub(crate) cancellation_token: CancellationToken,
|
pub(crate) cancellation_token: CancellationToken,
|
||||||
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||||
|
pub(crate) turn_context: Arc<TurnContext>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ActiveTurn {
|
impl ActiveTurn {
|
||||||
pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) {
|
pub(crate) fn add_task(&mut self, task: RunningTask) {
|
||||||
|
let sub_id = task.turn_context.sub_id.clone();
|
||||||
self.tasks.insert(sub_id, task);
|
self.tasks.insert(sub_id, task);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,8 +68,8 @@ impl ActiveTurn {
|
|||||||
self.tasks.is_empty()
|
self.tasks.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn drain_tasks(&mut self) -> IndexMap<String, RunningTask> {
|
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
|
||||||
std::mem::take(&mut self.tasks)
|
self.tasks.drain(..).map(|(_, task)| task).collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ impl SessionTask for CompactTask {
|
|||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
session: Arc<SessionTaskContext>,
|
session: Arc<SessionTaskContext>,
|
||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
_cancellation_token: CancellationToken,
|
_cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
compact::run_compact_task(session.clone_session(), ctx, input).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ use tracing::warn;
|
|||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
use crate::codex::TurnContext;
|
use crate::codex::TurnContext;
|
||||||
use crate::protocol::Event;
|
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::TaskCompleteEvent;
|
use crate::protocol::TaskCompleteEvent;
|
||||||
use crate::protocol::TurnAbortReason;
|
use crate::protocol::TurnAbortReason;
|
||||||
@@ -55,13 +54,12 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
|||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
session: Arc<SessionTaskContext>,
|
session: Arc<SessionTaskContext>,
|
||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String>;
|
) -> Option<String>;
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||||
let _ = (session, sub_id);
|
let _ = (session, ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,7 +67,6 @@ impl Session {
|
|||||||
pub async fn spawn_task<T: SessionTask>(
|
pub async fn spawn_task<T: SessionTask>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
task: T,
|
task: T,
|
||||||
) {
|
) {
|
||||||
@@ -86,14 +83,13 @@ impl Session {
|
|||||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||||
let ctx = Arc::clone(&turn_context);
|
let ctx = Arc::clone(&turn_context);
|
||||||
let task_for_run = Arc::clone(&task);
|
let task_for_run = Arc::clone(&task);
|
||||||
let sub_clone = sub_id.clone();
|
|
||||||
let task_cancellation_token = cancellation_token.child_token();
|
let task_cancellation_token = cancellation_token.child_token();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let ctx_for_finish = Arc::clone(&ctx);
|
||||||
let last_agent_message = task_for_run
|
let last_agent_message = task_for_run
|
||||||
.run(
|
.run(
|
||||||
Arc::clone(&session_ctx),
|
Arc::clone(&session_ctx),
|
||||||
ctx,
|
ctx,
|
||||||
sub_clone.clone(),
|
|
||||||
input,
|
input,
|
||||||
task_cancellation_token.child_token(),
|
task_cancellation_token.child_token(),
|
||||||
)
|
)
|
||||||
@@ -102,7 +98,8 @@ impl Session {
|
|||||||
if !task_cancellation_token.is_cancelled() {
|
if !task_cancellation_token.is_cancelled() {
|
||||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||||
let sess = session_ctx.clone_session();
|
let sess = session_ctx.clone_session();
|
||||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
sess.on_task_finished(ctx_for_finish, last_agent_message)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
done_clone.notify_waiters();
|
done_clone.notify_waiters();
|
||||||
})
|
})
|
||||||
@@ -114,60 +111,54 @@ impl Session {
|
|||||||
kind: task_kind,
|
kind: task_kind,
|
||||||
task,
|
task,
|
||||||
cancellation_token,
|
cancellation_token,
|
||||||
|
turn_context: Arc::clone(&turn_context),
|
||||||
};
|
};
|
||||||
self.register_new_active_task(sub_id, running_task).await;
|
self.register_new_active_task(running_task).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||||
for (sub_id, task) in self.take_all_running_tasks().await {
|
for task in self.take_all_running_tasks().await {
|
||||||
self.handle_task_abort(sub_id, task, reason.clone()).await;
|
self.handle_task_abort(task, reason.clone()).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn on_task_finished(
|
pub async fn on_task_finished(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
sub_id: String,
|
turn_context: Arc<TurnContext>,
|
||||||
last_agent_message: Option<String>,
|
last_agent_message: Option<String>,
|
||||||
) {
|
) {
|
||||||
let mut active = self.active_turn.lock().await;
|
let mut active = self.active_turn.lock().await;
|
||||||
if let Some(at) = active.as_mut()
|
if let Some(at) = active.as_mut()
|
||||||
&& at.remove_task(&sub_id)
|
&& at.remove_task(&turn_context.sub_id)
|
||||||
{
|
{
|
||||||
*active = None;
|
*active = None;
|
||||||
}
|
}
|
||||||
drop(active);
|
drop(active);
|
||||||
let event = Event {
|
let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
|
||||||
id: sub_id,
|
self.send_event(turn_context.as_ref(), event).await;
|
||||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn register_new_active_task(&self, sub_id: String, task: RunningTask) {
|
async fn register_new_active_task(&self, task: RunningTask) {
|
||||||
let mut active = self.active_turn.lock().await;
|
let mut active = self.active_turn.lock().await;
|
||||||
let mut turn = ActiveTurn::default();
|
let mut turn = ActiveTurn::default();
|
||||||
turn.add_task(sub_id, task);
|
turn.add_task(task);
|
||||||
*active = Some(turn);
|
*active = Some(turn);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> {
|
async fn take_all_running_tasks(&self) -> Vec<RunningTask> {
|
||||||
let mut active = self.active_turn.lock().await;
|
let mut active = self.active_turn.lock().await;
|
||||||
match active.take() {
|
match active.take() {
|
||||||
Some(mut at) => {
|
Some(mut at) => {
|
||||||
at.clear_pending().await;
|
at.clear_pending().await;
|
||||||
let tasks = at.drain_tasks();
|
|
||||||
tasks.into_iter().collect()
|
at.drain_tasks()
|
||||||
}
|
}
|
||||||
None => Vec::new(),
|
None => Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_task_abort(
|
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
|
||||||
self: &Arc<Self>,
|
let sub_id = task.turn_context.sub_id.clone();
|
||||||
sub_id: String,
|
|
||||||
task: RunningTask,
|
|
||||||
reason: TurnAbortReason,
|
|
||||||
) {
|
|
||||||
if task.cancellation_token.is_cancelled() {
|
if task.cancellation_token.is_cancelled() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -187,13 +178,12 @@ impl Session {
|
|||||||
task.handle.abort();
|
task.handle.abort();
|
||||||
|
|
||||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||||
session_task.abort(session_ctx, &sub_id).await;
|
session_task
|
||||||
|
.abort(session_ctx, Arc::clone(&task.turn_context))
|
||||||
|
.await;
|
||||||
|
|
||||||
let event = Event {
|
let event = EventMsg::TurnAborted(TurnAbortedEvent { reason });
|
||||||
id: sub_id.clone(),
|
self.send_event(task.turn_context.as_ref(), event).await;
|
||||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
|
||||||
};
|
|
||||||
self.send_event(event).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,19 +24,10 @@ impl SessionTask for RegularTask {
|
|||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
session: Arc<SessionTaskContext>,
|
session: Arc<SessionTaskContext>,
|
||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
let sess = session.clone_session();
|
let sess = session.clone_session();
|
||||||
run_task(
|
run_task(sess, ctx, input, TaskKind::Regular, cancellation_token).await
|
||||||
sess,
|
|
||||||
ctx,
|
|
||||||
sub_id,
|
|
||||||
input,
|
|
||||||
TaskKind::Regular,
|
|
||||||
cancellation_token,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,23 +25,14 @@ impl SessionTask for ReviewTask {
|
|||||||
self: Arc<Self>,
|
self: Arc<Self>,
|
||||||
session: Arc<SessionTaskContext>,
|
session: Arc<SessionTaskContext>,
|
||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
|
||||||
input: Vec<UserInput>,
|
input: Vec<UserInput>,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
let sess = session.clone_session();
|
let sess = session.clone_session();
|
||||||
run_task(
|
run_task(sess, ctx, input, TaskKind::Review, cancellation_token).await
|
||||||
sess,
|
|
||||||
ctx,
|
|
||||||
sub_id,
|
|
||||||
input,
|
|
||||||
TaskKind::Review,
|
|
||||||
cancellation_token,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
|
||||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
exit_review_mode(session.clone_session(), ctx, None).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ pub struct ToolInvocation {
|
|||||||
pub session: Arc<Session>,
|
pub session: Arc<Session>,
|
||||||
pub turn: Arc<TurnContext>,
|
pub turn: Arc<TurnContext>,
|
||||||
pub tracker: SharedTurnDiffTracker,
|
pub tracker: SharedTurnDiffTracker,
|
||||||
pub sub_id: String,
|
|
||||||
pub call_id: String,
|
pub call_id: String,
|
||||||
pub tool_name: String,
|
pub tool_name: String,
|
||||||
pub payload: ToolPayload,
|
pub payload: ToolPayload,
|
||||||
@@ -234,7 +233,7 @@ mod tests {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) struct ExecCommandContext {
|
pub(crate) struct ExecCommandContext {
|
||||||
pub(crate) sub_id: String,
|
pub(crate) turn: Arc<TurnContext>,
|
||||||
pub(crate) call_id: String,
|
pub(crate) call_id: String,
|
||||||
pub(crate) command_for_display: Vec<String>,
|
pub(crate) command_for_display: Vec<String>,
|
||||||
pub(crate) cwd: PathBuf,
|
pub(crate) cwd: PathBuf,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::exec::ExecToolCallOutput;
|
use crate::exec::ExecToolCallOutput;
|
||||||
use crate::parse_command::parse_command;
|
use crate::parse_command::parse_command;
|
||||||
use crate::protocol::Event;
|
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::ExecCommandBeginEvent;
|
use crate::protocol::ExecCommandBeginEvent;
|
||||||
use crate::protocol::ExecCommandEndEvent;
|
use crate::protocol::ExecCommandEndEvent;
|
||||||
@@ -20,7 +20,7 @@ use super::format_exec_output_str;
|
|||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub(crate) struct ToolEventCtx<'a> {
|
pub(crate) struct ToolEventCtx<'a> {
|
||||||
pub session: &'a Session,
|
pub session: &'a Session,
|
||||||
pub sub_id: &'a str,
|
pub turn: &'a TurnContext,
|
||||||
pub call_id: &'a str,
|
pub call_id: &'a str,
|
||||||
pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
||||||
}
|
}
|
||||||
@@ -28,13 +28,13 @@ pub(crate) struct ToolEventCtx<'a> {
|
|||||||
impl<'a> ToolEventCtx<'a> {
|
impl<'a> ToolEventCtx<'a> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
session: &'a Session,
|
session: &'a Session,
|
||||||
sub_id: &'a str,
|
turn: &'a TurnContext,
|
||||||
call_id: &'a str,
|
call_id: &'a str,
|
||||||
turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
session,
|
session,
|
||||||
sub_id,
|
turn,
|
||||||
call_id,
|
call_id,
|
||||||
turn_diff_tracker,
|
turn_diff_tracker,
|
||||||
}
|
}
|
||||||
@@ -79,15 +79,15 @@ impl ToolEmitter {
|
|||||||
match (self, stage) {
|
match (self, stage) {
|
||||||
(Self::Shell { command, cwd }, ToolEventStage::Begin) => {
|
(Self::Shell { command, cwd }, ToolEventStage::Begin) => {
|
||||||
ctx.session
|
ctx.session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: ctx.sub_id.to_string(),
|
ctx.turn,
|
||||||
msg: EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
|
||||||
call_id: ctx.call_id.to_string(),
|
call_id: ctx.call_id.to_string(),
|
||||||
command: command.clone(),
|
command: command.clone(),
|
||||||
cwd: cwd.clone(),
|
cwd: cwd.clone(),
|
||||||
parsed_cmd: parse_command(command),
|
parsed_cmd: parse_command(command),
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
(Self::Shell { .. }, ToolEventStage::Success(output)) => {
|
(Self::Shell { .. }, ToolEventStage::Success(output)) => {
|
||||||
@@ -139,14 +139,14 @@ impl ToolEmitter {
|
|||||||
guard.on_patch_begin(changes);
|
guard.on_patch_begin(changes);
|
||||||
}
|
}
|
||||||
ctx.session
|
ctx.session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: ctx.sub_id.to_string(),
|
ctx.turn,
|
||||||
msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||||
call_id: ctx.call_id.to_string(),
|
call_id: ctx.call_id.to_string(),
|
||||||
auto_approved: *auto_approved,
|
auto_approved: *auto_approved,
|
||||||
changes: changes.clone(),
|
changes: changes.clone(),
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
(Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => {
|
(Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => {
|
||||||
@@ -190,9 +190,9 @@ async fn emit_exec_end(
|
|||||||
formatted_output: String,
|
formatted_output: String,
|
||||||
) {
|
) {
|
||||||
ctx.session
|
ctx.session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: ctx.sub_id.to_string(),
|
ctx.turn,
|
||||||
msg: EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||||
call_id: ctx.call_id.to_string(),
|
call_id: ctx.call_id.to_string(),
|
||||||
stdout,
|
stdout,
|
||||||
stderr,
|
stderr,
|
||||||
@@ -201,21 +201,21 @@ async fn emit_exec_end(
|
|||||||
duration,
|
duration,
|
||||||
formatted_output,
|
formatted_output,
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) {
|
async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) {
|
||||||
ctx.session
|
ctx.session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: ctx.sub_id.to_string(),
|
ctx.turn,
|
||||||
msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||||
call_id: ctx.call_id.to_string(),
|
call_id: ctx.call_id.to_string(),
|
||||||
stdout,
|
stdout,
|
||||||
stderr,
|
stderr,
|
||||||
success,
|
success,
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
if let Some(tracker) = ctx.turn_diff_tracker {
|
if let Some(tracker) = ctx.turn_diff_tracker {
|
||||||
@@ -225,10 +225,7 @@ async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, s
|
|||||||
};
|
};
|
||||||
if let Ok(Some(unified_diff)) = unified_diff {
|
if let Ok(Some(unified_diff)) = unified_diff {
|
||||||
ctx.session
|
ctx.session
|
||||||
.send_event(Event {
|
.send_event(ctx.turn, EventMsg::TurnDiff(TurnDiffEvent { unified_diff }))
|
||||||
id: ctx.sub_id.to_string(),
|
|
||||||
msg: EventMsg::TurnDiff(TurnDiffEvent { unified_diff }),
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ impl ToolHandler for ApplyPatchHandler {
|
|||||||
session,
|
session,
|
||||||
turn,
|
turn,
|
||||||
tracker,
|
tracker,
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
payload,
|
payload,
|
||||||
@@ -81,7 +80,6 @@ impl ToolHandler for ApplyPatchHandler {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn),
|
Arc::clone(&turn),
|
||||||
Arc::clone(&tracker),
|
Arc::clone(&tracker),
|
||||||
sub_id.clone(),
|
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ impl ToolHandler for McpHandler {
|
|||||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||||
let ToolInvocation {
|
let ToolInvocation {
|
||||||
session,
|
session,
|
||||||
sub_id,
|
turn,
|
||||||
call_id,
|
call_id,
|
||||||
payload,
|
payload,
|
||||||
..
|
..
|
||||||
@@ -43,7 +43,7 @@ impl ToolHandler for McpHandler {
|
|||||||
|
|
||||||
let response = handle_mcp_tool_call(
|
let response = handle_mcp_tool_call(
|
||||||
session.as_ref(),
|
session.as_ref(),
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
server,
|
server,
|
||||||
tool,
|
tool,
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ use serde::de::DeserializeOwned;
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::function_tool::FunctionCallError;
|
use crate::function_tool::FunctionCallError;
|
||||||
use crate::protocol::Event;
|
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::McpInvocation;
|
use crate::protocol::McpInvocation;
|
||||||
use crate::protocol::McpToolCallBeginEvent;
|
use crate::protocol::McpToolCallBeginEvent;
|
||||||
@@ -189,7 +189,7 @@ impl ToolHandler for McpResourceHandler {
|
|||||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||||
let ToolInvocation {
|
let ToolInvocation {
|
||||||
session,
|
session,
|
||||||
sub_id,
|
turn,
|
||||||
call_id,
|
call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
payload,
|
payload,
|
||||||
@@ -211,7 +211,7 @@ impl ToolHandler for McpResourceHandler {
|
|||||||
"list_mcp_resources" => {
|
"list_mcp_resources" => {
|
||||||
handle_list_resources(
|
handle_list_resources(
|
||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
sub_id.clone(),
|
Arc::clone(&turn),
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
arguments_value.clone(),
|
arguments_value.clone(),
|
||||||
)
|
)
|
||||||
@@ -220,14 +220,20 @@ impl ToolHandler for McpResourceHandler {
|
|||||||
"list_mcp_resource_templates" => {
|
"list_mcp_resource_templates" => {
|
||||||
handle_list_resource_templates(
|
handle_list_resource_templates(
|
||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
sub_id.clone(),
|
Arc::clone(&turn),
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
arguments_value.clone(),
|
arguments_value.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
"read_mcp_resource" => {
|
"read_mcp_resource" => {
|
||||||
handle_read_resource(Arc::clone(&session), sub_id, call_id, arguments_value).await
|
handle_read_resource(
|
||||||
|
Arc::clone(&session),
|
||||||
|
Arc::clone(&turn),
|
||||||
|
call_id,
|
||||||
|
arguments_value,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
other => Err(FunctionCallError::RespondToModel(format!(
|
other => Err(FunctionCallError::RespondToModel(format!(
|
||||||
"unsupported MCP resource tool: {other}"
|
"unsupported MCP resource tool: {other}"
|
||||||
@@ -238,7 +244,7 @@ impl ToolHandler for McpResourceHandler {
|
|||||||
|
|
||||||
async fn handle_list_resources(
|
async fn handle_list_resources(
|
||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
sub_id: String,
|
turn: Arc<TurnContext>,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
arguments: Option<Value>,
|
arguments: Option<Value>,
|
||||||
) -> Result<ToolOutput, FunctionCallError> {
|
) -> Result<ToolOutput, FunctionCallError> {
|
||||||
@@ -253,7 +259,7 @@ async fn handle_list_resources(
|
|||||||
arguments: arguments.clone(),
|
arguments: arguments.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
|
let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
|
||||||
@@ -297,7 +303,7 @@ async fn handle_list_resources(
|
|||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -311,7 +317,7 @@ async fn handle_list_resources(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -326,7 +332,7 @@ async fn handle_list_resources(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -340,7 +346,7 @@ async fn handle_list_resources(
|
|||||||
|
|
||||||
async fn handle_list_resource_templates(
|
async fn handle_list_resource_templates(
|
||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
sub_id: String,
|
turn: Arc<TurnContext>,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
arguments: Option<Value>,
|
arguments: Option<Value>,
|
||||||
) -> Result<ToolOutput, FunctionCallError> {
|
) -> Result<ToolOutput, FunctionCallError> {
|
||||||
@@ -355,7 +361,7 @@ async fn handle_list_resource_templates(
|
|||||||
arguments: arguments.clone(),
|
arguments: arguments.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
|
let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
|
||||||
@@ -403,7 +409,7 @@ async fn handle_list_resource_templates(
|
|||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -417,7 +423,7 @@ async fn handle_list_resource_templates(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -432,7 +438,7 @@ async fn handle_list_resource_templates(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -446,7 +452,7 @@ async fn handle_list_resource_templates(
|
|||||||
|
|
||||||
async fn handle_read_resource(
|
async fn handle_read_resource(
|
||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
sub_id: String,
|
turn: Arc<TurnContext>,
|
||||||
call_id: String,
|
call_id: String,
|
||||||
arguments: Option<Value>,
|
arguments: Option<Value>,
|
||||||
) -> Result<ToolOutput, FunctionCallError> {
|
) -> Result<ToolOutput, FunctionCallError> {
|
||||||
@@ -461,7 +467,7 @@ async fn handle_read_resource(
|
|||||||
arguments: arguments.clone(),
|
arguments: arguments.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await;
|
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
|
let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
|
||||||
@@ -489,7 +495,7 @@ async fn handle_read_resource(
|
|||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -503,7 +509,7 @@ async fn handle_read_resource(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -518,7 +524,7 @@ async fn handle_read_resource(
|
|||||||
let message = err.to_string();
|
let message = err.to_string();
|
||||||
emit_tool_call_end(
|
emit_tool_call_end(
|
||||||
&session,
|
&session,
|
||||||
&sub_id,
|
turn.as_ref(),
|
||||||
&call_id,
|
&call_id,
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
@@ -544,39 +550,39 @@ fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallTo
|
|||||||
|
|
||||||
async fn emit_tool_call_begin(
|
async fn emit_tool_call_begin(
|
||||||
session: &Arc<Session>,
|
session: &Arc<Session>,
|
||||||
sub_id: &str,
|
turn: &TurnContext,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
invocation: McpInvocation,
|
invocation: McpInvocation,
|
||||||
) {
|
) {
|
||||||
session
|
session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: sub_id.to_string(),
|
turn,
|
||||||
msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||||
call_id: call_id.to_string(),
|
call_id: call_id.to_string(),
|
||||||
invocation,
|
invocation,
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn emit_tool_call_end(
|
async fn emit_tool_call_end(
|
||||||
session: &Arc<Session>,
|
session: &Arc<Session>,
|
||||||
sub_id: &str,
|
turn: &TurnContext,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
invocation: McpInvocation,
|
invocation: McpInvocation,
|
||||||
duration: Duration,
|
duration: Duration,
|
||||||
result: Result<CallToolResult, String>,
|
result: Result<CallToolResult, String>,
|
||||||
) {
|
) {
|
||||||
session
|
session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: sub_id.to_string(),
|
turn,
|
||||||
msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
EventMsg::McpToolCallEnd(McpToolCallEndEvent {
|
||||||
call_id: call_id.to_string(),
|
call_id: call_id.to_string(),
|
||||||
invocation,
|
invocation,
|
||||||
duration,
|
duration,
|
||||||
result,
|
result,
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use crate::client_common::tools::ResponsesApiTool;
|
use crate::client_common::tools::ResponsesApiTool;
|
||||||
use crate::client_common::tools::ToolSpec;
|
use crate::client_common::tools::ToolSpec;
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::function_tool::FunctionCallError;
|
use crate::function_tool::FunctionCallError;
|
||||||
use crate::tools::context::ToolInvocation;
|
use crate::tools::context::ToolInvocation;
|
||||||
use crate::tools::context::ToolOutput;
|
use crate::tools::context::ToolOutput;
|
||||||
@@ -10,7 +11,6 @@ use crate::tools::registry::ToolKind;
|
|||||||
use crate::tools::spec::JsonSchema;
|
use crate::tools::spec::JsonSchema;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use codex_protocol::plan_tool::UpdatePlanArgs;
|
use codex_protocol::plan_tool::UpdatePlanArgs;
|
||||||
use codex_protocol::protocol::Event;
|
|
||||||
use codex_protocol::protocol::EventMsg;
|
use codex_protocol::protocol::EventMsg;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::sync::LazyLock;
|
use std::sync::LazyLock;
|
||||||
@@ -68,7 +68,7 @@ impl ToolHandler for PlanHandler {
|
|||||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||||
let ToolInvocation {
|
let ToolInvocation {
|
||||||
session,
|
session,
|
||||||
sub_id,
|
turn,
|
||||||
call_id,
|
call_id,
|
||||||
payload,
|
payload,
|
||||||
..
|
..
|
||||||
@@ -84,7 +84,7 @@ impl ToolHandler for PlanHandler {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let content =
|
let content =
|
||||||
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
|
handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?;
|
||||||
|
|
||||||
Ok(ToolOutput::Function {
|
Ok(ToolOutput::Function {
|
||||||
content,
|
content,
|
||||||
@@ -98,16 +98,13 @@ impl ToolHandler for PlanHandler {
|
|||||||
/// than forcing it to come up and document a plan (TBD how that affects performance).
|
/// than forcing it to come up and document a plan (TBD how that affects performance).
|
||||||
pub(crate) async fn handle_update_plan(
|
pub(crate) async fn handle_update_plan(
|
||||||
session: &Session,
|
session: &Session,
|
||||||
|
turn_context: &TurnContext,
|
||||||
arguments: String,
|
arguments: String,
|
||||||
sub_id: String,
|
|
||||||
_call_id: String,
|
_call_id: String,
|
||||||
) -> Result<String, FunctionCallError> {
|
) -> Result<String, FunctionCallError> {
|
||||||
let args = parse_update_plan_arguments(&arguments)?;
|
let args = parse_update_plan_arguments(&arguments)?;
|
||||||
session
|
session
|
||||||
.send_event(Event {
|
.send_event(turn_context, EventMsg::PlanUpdate(args))
|
||||||
id: sub_id.to_string(),
|
|
||||||
msg: EventMsg::PlanUpdate(args),
|
|
||||||
})
|
|
||||||
.await;
|
.await;
|
||||||
Ok("Plan updated".to_string())
|
Ok("Plan updated".to_string())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ impl ToolHandler for ShellHandler {
|
|||||||
session,
|
session,
|
||||||
turn,
|
turn,
|
||||||
tracker,
|
tracker,
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
payload,
|
payload,
|
||||||
@@ -68,7 +67,6 @@ impl ToolHandler for ShellHandler {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn),
|
Arc::clone(&turn),
|
||||||
Arc::clone(&tracker),
|
Arc::clone(&tracker),
|
||||||
sub_id.clone(),
|
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -85,7 +83,6 @@ impl ToolHandler for ShellHandler {
|
|||||||
Arc::clone(&session),
|
Arc::clone(&session),
|
||||||
Arc::clone(&turn),
|
Arc::clone(&turn),
|
||||||
Arc::clone(&tracker),
|
Arc::clone(&tracker),
|
||||||
sub_id.clone(),
|
|
||||||
call_id.clone(),
|
call_id.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ impl ToolHandler for UnifiedExecHandler {
|
|||||||
let ToolInvocation {
|
let ToolInvocation {
|
||||||
session,
|
session,
|
||||||
turn,
|
turn,
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
tool_name: _tool_name,
|
tool_name: _tool_name,
|
||||||
payload,
|
payload,
|
||||||
@@ -91,7 +90,6 @@ impl ToolHandler for UnifiedExecHandler {
|
|||||||
crate::unified_exec::UnifiedExecContext {
|
crate::unified_exec::UnifiedExecContext {
|
||||||
session: &session,
|
session: &session,
|
||||||
turn: turn.as_ref(),
|
turn: turn.as_ref(),
|
||||||
sub_id: &sub_id,
|
|
||||||
call_id: &call_id,
|
call_id: &call_id,
|
||||||
session_id: parsed_session_id,
|
session_id: parsed_session_id,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ use serde::Deserialize;
|
|||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
|
|
||||||
use crate::function_tool::FunctionCallError;
|
use crate::function_tool::FunctionCallError;
|
||||||
use crate::protocol::Event;
|
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::ViewImageToolCallEvent;
|
use crate::protocol::ViewImageToolCallEvent;
|
||||||
use crate::tools::context::ToolInvocation;
|
use crate::tools::context::ToolInvocation;
|
||||||
@@ -31,7 +30,6 @@ impl ToolHandler for ViewImageHandler {
|
|||||||
session,
|
session,
|
||||||
turn,
|
turn,
|
||||||
payload,
|
payload,
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
..
|
..
|
||||||
} = invocation;
|
} = invocation;
|
||||||
@@ -76,13 +74,13 @@ impl ToolHandler for ViewImageHandler {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
session
|
session
|
||||||
.send_event(Event {
|
.send_event(
|
||||||
id: sub_id.to_string(),
|
turn.as_ref(),
|
||||||
msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
|
EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
|
||||||
call_id,
|
call_id,
|
||||||
path: event_path,
|
path: event_path,
|
||||||
}),
|
}),
|
||||||
})
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
Ok(ToolOutput::Function {
|
Ok(ToolOutput::Function {
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ pub(crate) async fn handle_container_exec_with_params(
|
|||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
turn_diff_tracker: SharedTurnDiffTracker,
|
turn_diff_tracker: SharedTurnDiffTracker,
|
||||||
sub_id: String,
|
|
||||||
call_id: String,
|
call_id: String,
|
||||||
) -> Result<String, FunctionCallError> {
|
) -> Result<String, FunctionCallError> {
|
||||||
let _otel_event_manager = turn_context.client.get_otel_event_manager();
|
let _otel_event_manager = turn_context.client.get_otel_event_manager();
|
||||||
@@ -78,14 +77,8 @@ pub(crate) async fn handle_container_exec_with_params(
|
|||||||
// check if this was a patch, and apply it if so
|
// check if this was a patch, and apply it if so
|
||||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||||
MaybeApplyPatchVerified::Body(changes) => {
|
MaybeApplyPatchVerified::Body(changes) => {
|
||||||
match apply_patch::apply_patch(
|
match apply_patch::apply_patch(sess.as_ref(), turn_context.as_ref(), &call_id, changes)
|
||||||
sess.as_ref(),
|
.await
|
||||||
turn_context.as_ref(),
|
|
||||||
&sub_id,
|
|
||||||
&call_id,
|
|
||||||
changes,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
{
|
||||||
InternalApplyPatchInvocation::Output(item) => return item,
|
InternalApplyPatchInvocation::Output(item) => return item,
|
||||||
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
||||||
@@ -122,7 +115,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
let event_ctx = ToolEventCtx::new(sess.as_ref(), &sub_id, &call_id, diff_opt);
|
let event_ctx = ToolEventCtx::new(sess.as_ref(), turn_context.as_ref(), &call_id, diff_opt);
|
||||||
event_emitter.emit(event_ctx, ToolEventStage::Begin).await;
|
event_emitter.emit(event_ctx, ToolEventStage::Begin).await;
|
||||||
|
|
||||||
// Build runtime contexts only when needed (shell/apply_patch below).
|
// Build runtime contexts only when needed (shell/apply_patch below).
|
||||||
@@ -141,7 +134,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
|||||||
let mut runtime = ApplyPatchRuntime::new();
|
let mut runtime = ApplyPatchRuntime::new();
|
||||||
let tool_ctx = ToolCtx {
|
let tool_ctx = ToolCtx {
|
||||||
session: sess.as_ref(),
|
session: sess.as_ref(),
|
||||||
sub_id: sub_id.clone(),
|
turn: turn_context.as_ref(),
|
||||||
call_id: call_id.clone(),
|
call_id: call_id.clone(),
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: tool_name.to_string(),
|
||||||
};
|
};
|
||||||
@@ -172,7 +165,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
|||||||
let mut runtime = ShellRuntime::new();
|
let mut runtime = ShellRuntime::new();
|
||||||
let tool_ctx = ToolCtx {
|
let tool_ctx = ToolCtx {
|
||||||
session: sess.as_ref(),
|
session: sess.as_ref(),
|
||||||
sub_id: sub_id.clone(),
|
turn: turn_context.as_ref(),
|
||||||
call_id: call_id.clone(),
|
call_id: call_id.clone(),
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: tool_name.to_string(),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ impl ToolOrchestrator {
|
|||||||
if needs_initial_approval {
|
if needs_initial_approval {
|
||||||
let approval_ctx = ApprovalCtx {
|
let approval_ctx = ApprovalCtx {
|
||||||
session: tool_ctx.session,
|
session: tool_ctx.session,
|
||||||
sub_id: &tool_ctx.sub_id,
|
turn: turn_ctx,
|
||||||
call_id: &tool_ctx.call_id,
|
call_id: &tool_ctx.call_id,
|
||||||
retry_reason: None,
|
retry_reason: None,
|
||||||
};
|
};
|
||||||
@@ -110,7 +110,7 @@ impl ToolOrchestrator {
|
|||||||
let reason_msg = build_denial_reason_from_output(output.as_ref());
|
let reason_msg = build_denial_reason_from_output(output.as_ref());
|
||||||
let approval_ctx = ApprovalCtx {
|
let approval_ctx = ApprovalCtx {
|
||||||
session: tool_ctx.session,
|
session: tool_ctx.session,
|
||||||
sub_id: &tool_ctx.sub_id,
|
turn: turn_ctx,
|
||||||
call_id: &tool_ctx.call_id,
|
call_id: &tool_ctx.call_id,
|
||||||
retry_reason: Some(reason_msg),
|
retry_reason: Some(reason_msg),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ pub(crate) struct ToolCallRuntime {
|
|||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
tracker: SharedTurnDiffTracker,
|
tracker: SharedTurnDiffTracker,
|
||||||
sub_id: String,
|
|
||||||
parallel_execution: Arc<RwLock<()>>,
|
parallel_execution: Arc<RwLock<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,14 +27,12 @@ impl ToolCallRuntime {
|
|||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
tracker: SharedTurnDiffTracker,
|
tracker: SharedTurnDiffTracker,
|
||||||
sub_id: String,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
router,
|
router,
|
||||||
session,
|
session,
|
||||||
turn_context,
|
turn_context,
|
||||||
tracker,
|
tracker,
|
||||||
sub_id,
|
|
||||||
parallel_execution: Arc::new(RwLock::new(())),
|
parallel_execution: Arc::new(RwLock::new(())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -50,7 +47,6 @@ impl ToolCallRuntime {
|
|||||||
let session = Arc::clone(&self.session);
|
let session = Arc::clone(&self.session);
|
||||||
let turn = Arc::clone(&self.turn_context);
|
let turn = Arc::clone(&self.turn_context);
|
||||||
let tracker = Arc::clone(&self.tracker);
|
let tracker = Arc::clone(&self.tracker);
|
||||||
let sub_id = self.sub_id.clone();
|
|
||||||
let lock = Arc::clone(&self.parallel_execution);
|
let lock = Arc::clone(&self.parallel_execution);
|
||||||
|
|
||||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||||
@@ -62,7 +58,7 @@ impl ToolCallRuntime {
|
|||||||
};
|
};
|
||||||
|
|
||||||
router
|
router
|
||||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
.dispatch_tool_call(session, turn, tracker, call)
|
||||||
.await
|
.await
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ impl ToolRouter {
|
|||||||
session: Arc<Session>,
|
session: Arc<Session>,
|
||||||
turn: Arc<TurnContext>,
|
turn: Arc<TurnContext>,
|
||||||
tracker: SharedTurnDiffTracker,
|
tracker: SharedTurnDiffTracker,
|
||||||
sub_id: String,
|
|
||||||
call: ToolCall,
|
call: ToolCall,
|
||||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||||
let ToolCall {
|
let ToolCall {
|
||||||
@@ -149,7 +148,6 @@ impl ToolRouter {
|
|||||||
session,
|
session,
|
||||||
turn,
|
turn,
|
||||||
tracker,
|
tracker,
|
||||||
sub_id,
|
|
||||||
call_id,
|
call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
payload,
|
payload,
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ impl ApplyPatchRuntime {
|
|||||||
|
|
||||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||||
Some(crate::exec::StdoutStream {
|
Some(crate::exec::StdoutStream {
|
||||||
sub_id: ctx.sub_id.clone(),
|
sub_id: ctx.turn.sub_id.clone(),
|
||||||
call_id: ctx.call_id.clone(),
|
call_id: ctx.call_id.clone(),
|
||||||
tx_event: ctx.session.get_tx_event(),
|
tx_event: ctx.session.get_tx_event(),
|
||||||
})
|
})
|
||||||
@@ -101,7 +101,7 @@ impl Approvable<ApplyPatchRequest> for ApplyPatchRuntime {
|
|||||||
) -> BoxFuture<'a, ReviewDecision> {
|
) -> BoxFuture<'a, ReviewDecision> {
|
||||||
let key = self.approval_key(req);
|
let key = self.approval_key(req);
|
||||||
let session = ctx.session;
|
let session = ctx.session;
|
||||||
let sub_id = ctx.sub_id.to_string();
|
let turn = ctx.turn;
|
||||||
let call_id = ctx.call_id.to_string();
|
let call_id = ctx.call_id.to_string();
|
||||||
let cwd = req.cwd.clone();
|
let cwd = req.cwd.clone();
|
||||||
let retry_reason = ctx.retry_reason.clone();
|
let retry_reason = ctx.retry_reason.clone();
|
||||||
@@ -111,7 +111,7 @@ impl Approvable<ApplyPatchRequest> for ApplyPatchRuntime {
|
|||||||
if let Some(reason) = retry_reason {
|
if let Some(reason) = retry_reason {
|
||||||
session
|
session
|
||||||
.request_command_approval(
|
.request_command_approval(
|
||||||
sub_id,
|
turn,
|
||||||
call_id,
|
call_id,
|
||||||
vec!["apply_patch".to_string()],
|
vec!["apply_patch".to_string()],
|
||||||
cwd,
|
cwd,
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ impl ShellRuntime {
|
|||||||
|
|
||||||
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
fn stdout_stream(ctx: &ToolCtx<'_>) -> Option<crate::exec::StdoutStream> {
|
||||||
Some(crate::exec::StdoutStream {
|
Some(crate::exec::StdoutStream {
|
||||||
sub_id: ctx.sub_id.clone(),
|
sub_id: ctx.turn.sub_id.clone(),
|
||||||
call_id: ctx.call_id.clone(),
|
call_id: ctx.call_id.clone(),
|
||||||
tx_event: ctx.session.get_tx_event(),
|
tx_event: ctx.session.get_tx_event(),
|
||||||
})
|
})
|
||||||
@@ -91,12 +91,12 @@ impl Approvable<ShellRequest> for ShellRuntime {
|
|||||||
.clone()
|
.clone()
|
||||||
.or_else(|| req.justification.clone());
|
.or_else(|| req.justification.clone());
|
||||||
let session = ctx.session;
|
let session = ctx.session;
|
||||||
let sub_id = ctx.sub_id.to_string();
|
let turn = ctx.turn;
|
||||||
let call_id = ctx.call_id.to_string();
|
let call_id = ctx.call_id.to_string();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
with_cached_approval(&session.services, key, || async move {
|
with_cached_approval(&session.services, key, || async move {
|
||||||
session
|
session
|
||||||
.request_command_approval(sub_id, call_id, command, cwd, reason)
|
.request_command_approval(turn, call_id, command, cwd, reason)
|
||||||
.await
|
.await
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ impl Approvable<UnifiedExecRequest> for UnifiedExecRuntime<'_> {
|
|||||||
) -> BoxFuture<'b, ReviewDecision> {
|
) -> BoxFuture<'b, ReviewDecision> {
|
||||||
let key = self.approval_key(req);
|
let key = self.approval_key(req);
|
||||||
let session = ctx.session;
|
let session = ctx.session;
|
||||||
let sub_id = ctx.sub_id.to_string();
|
let turn = ctx.turn;
|
||||||
let call_id = ctx.call_id.to_string();
|
let call_id = ctx.call_id.to_string();
|
||||||
let command = req.command.clone();
|
let command = req.command.clone();
|
||||||
let cwd = req.cwd.clone();
|
let cwd = req.cwd.clone();
|
||||||
@@ -88,7 +88,7 @@ impl Approvable<UnifiedExecRequest> for UnifiedExecRuntime<'_> {
|
|||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
with_cached_approval(&session.services, key, || async move {
|
with_cached_approval(&session.services, key, || async move {
|
||||||
session
|
session
|
||||||
.request_command_approval(sub_id, call_id, command, cwd, reason)
|
.request_command_approval(turn, call_id, command, cwd, reason)
|
||||||
.await
|
.await
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
//! and helpers (`Sandboxable`, `ToolRuntime`, `SandboxAttempt`, etc.).
|
//! and helpers (`Sandboxable`, `ToolRuntime`, `SandboxAttempt`, etc.).
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
|
use crate::codex::TurnContext;
|
||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
use crate::protocol::SandboxPolicy;
|
use crate::protocol::SandboxPolicy;
|
||||||
use crate::sandboxing::CommandSpec;
|
use crate::sandboxing::CommandSpec;
|
||||||
@@ -77,7 +78,7 @@ where
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct ApprovalCtx<'a> {
|
pub(crate) struct ApprovalCtx<'a> {
|
||||||
pub session: &'a Session,
|
pub session: &'a Session,
|
||||||
pub sub_id: &'a str,
|
pub turn: &'a TurnContext,
|
||||||
pub call_id: &'a str,
|
pub call_id: &'a str,
|
||||||
pub retry_reason: Option<String>,
|
pub retry_reason: Option<String>,
|
||||||
}
|
}
|
||||||
@@ -145,7 +146,7 @@ pub(crate) trait Sandboxable {
|
|||||||
|
|
||||||
pub(crate) struct ToolCtx<'a> {
|
pub(crate) struct ToolCtx<'a> {
|
||||||
pub session: &'a Session,
|
pub session: &'a Session,
|
||||||
pub sub_id: String,
|
pub turn: &'a TurnContext,
|
||||||
pub call_id: String,
|
pub call_id: String,
|
||||||
pub tool_name: String,
|
pub tool_name: String,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
|||||||
pub(crate) struct UnifiedExecContext<'a> {
|
pub(crate) struct UnifiedExecContext<'a> {
|
||||||
pub session: &'a Session,
|
pub session: &'a Session,
|
||||||
pub turn: &'a TurnContext,
|
pub turn: &'a TurnContext,
|
||||||
pub sub_id: &'a str,
|
|
||||||
pub call_id: &'a str,
|
pub call_id: &'a str,
|
||||||
pub session_id: Option<i32>,
|
pub session_id: Option<i32>,
|
||||||
}
|
}
|
||||||
@@ -110,7 +109,6 @@ mod tests {
|
|||||||
UnifiedExecContext {
|
UnifiedExecContext {
|
||||||
session,
|
session,
|
||||||
turn: turn.as_ref(),
|
turn: turn.as_ref(),
|
||||||
sub_id: "sub",
|
|
||||||
call_id: "call",
|
call_id: "call",
|
||||||
session_id,
|
session_id,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ impl UnifiedExecSessionManager {
|
|||||||
);
|
);
|
||||||
let tool_ctx = ToolCtx {
|
let tool_ctx = ToolCtx {
|
||||||
session: context.session,
|
session: context.session,
|
||||||
sub_id: context.sub_id.to_string(),
|
turn: context.turn,
|
||||||
call_id: context.call_id.to_string(),
|
call_id: context.call_id.to_string(),
|
||||||
tool_name: "unified_exec".to_string(),
|
tool_name: "unified_exec".to_string(),
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user