diff --git a/codex-rs/core/src/apply_patch.rs b/codex-rs/core/src/apply_patch.rs index 5b6728ad..a37b60c0 100644 --- a/codex-rs/core/src/apply_patch.rs +++ b/codex-rs/core/src/apply_patch.rs @@ -36,7 +36,6 @@ pub(crate) struct ApplyPatchExec { pub(crate) async fn apply_patch( sess: &Session, turn_context: &TurnContext, - sub_id: &str, call_id: &str, action: ApplyPatchAction, ) -> InternalApplyPatchInvocation { @@ -62,7 +61,7 @@ pub(crate) async fn apply_patch( // that similar patches can be auto-approved in the future during // this session. let rx_approve = sess - .request_patch_approval(sub_id.to_owned(), call_id.to_owned(), &action, None, None) + .request_patch_approval(turn_context, call_id.to_owned(), &action, None, None) .await; match rx_approve.await.unwrap_or_default() { ReviewDecision::Approved | ReviewDecision::ApprovedForSession => { diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index af1d9ef0..3aee5784 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -255,6 +255,7 @@ pub(crate) struct Session { /// The context needed for a single turn of the conversation. #[derive(Debug)] pub(crate) struct TurnContext { + pub(crate) sub_id: String, pub(crate) client: ModelClient, /// The session's current working directory. All relative paths provided by /// the model as well as sandbox policies are resolved against this path @@ -359,6 +360,7 @@ impl Session { session_configuration: &SessionConfiguration, conversation_id: ConversationId, tx_event: Sender, + sub_id: String, ) -> TurnContext { let config = session_configuration.original_config_do_not_use.clone(); let model_family = find_family_for_model(&session_configuration.model) @@ -392,7 +394,10 @@ impl Session { features: &config.features, }); + let item_collector = ItemCollector::new(tx_event, conversation_id, sub_id.clone()); + TurnContext { + sub_id, client, cwd: session_configuration.cwd.clone(), base_instructions: session_configuration.base_instructions.clone(), @@ -404,7 +409,7 @@ impl Session { is_review_mode: false, final_output_json_schema: None, codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), - item_collector: ItemCollector::new(tx_event, conversation_id, "turn_id".to_string()), + item_collector, } } @@ -587,7 +592,7 @@ impl Session { }) .chain(post_session_configured_error_events.into_iter()); for event in events { - sess.send_event(event).await; + sess.send_event_raw(event).await; } Ok(sess) @@ -638,6 +643,15 @@ impl Session { } pub(crate) async fn new_turn(&self, updates: SessionSettingsUpdate) -> Arc { + let sub_id = self.next_internal_sub_id(); + self.new_turn_with_sub_id(sub_id, updates).await + } + + pub(crate) async fn new_turn_with_sub_id( + &self, + sub_id: String, + updates: SessionSettingsUpdate, + ) -> Arc { let session_configuration = { let mut state = self.state.lock().await; let session_configuration = state.session_configuration.clone().apply(&updates); @@ -652,6 +666,7 @@ impl Session { &session_configuration, self.conversation_id, self.get_tx_event(), + sub_id, ); if let Some(final_schema) = updates.final_output_json_schema { turn_context.final_output_json_schema = final_schema; @@ -678,7 +693,15 @@ impl Session { } /// Persist the event to rollout and send it to clients. - pub(crate) async fn send_event(&self, event: Event) { + pub(crate) async fn send_event(&self, turn_context: &TurnContext, msg: EventMsg) { + let event = Event { + id: turn_context.sub_id.clone(), + msg, + }; + self.send_event_raw(event).await; + } + + pub(crate) async fn send_event_raw(&self, event: Event) { // Persist the event into rollout (recorder filters as needed) let rollout_items = vec![RolloutItem::EventMsg(event.msg.clone())]; self.persist_rollout_items(&rollout_items).await; @@ -694,12 +717,13 @@ impl Session { /// default `ReviewDecision` (`Denied`). pub async fn request_command_approval( &self, - sub_id: String, + turn_context: &TurnContext, call_id: String, command: Vec, cwd: PathBuf, reason: Option, ) -> ReviewDecision { + let sub_id = turn_context.sub_id.clone(); // Add the tx_approve callback to the map before sending the request. let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); @@ -718,28 +742,26 @@ impl Session { } let parsed_cmd = parse_command(&command); - let event = Event { - id: event_id, - msg: EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { - call_id, - command, - cwd, - reason, - parsed_cmd, - }), - }; - self.send_event(event).await; + let event = EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + call_id, + command, + cwd, + reason, + parsed_cmd, + }); + self.send_event(turn_context, event).await; rx_approve.await.unwrap_or_default() } pub async fn request_patch_approval( &self, - sub_id: String, + turn_context: &TurnContext, call_id: String, action: &ApplyPatchAction, reason: Option, grant_root: Option, ) -> oneshot::Receiver { + let sub_id = turn_context.sub_id.clone(); // Add the tx_approve callback to the map before sending the request. let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); @@ -757,16 +779,13 @@ impl Session { warn!("Overwriting existing pending approval for sub_id: {event_id}"); } - let event = Event { - id: event_id, - msg: EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { - call_id, - changes: convert_apply_patch_to_protocol(action), - reason, - grant_root, - }), - }; - self.send_event(event).await; + let event = EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + call_id, + changes: convert_apply_patch_to_protocol(action), + reason, + grant_root, + }); + self.send_event(turn_context, event).await; rx_approve } @@ -878,7 +897,6 @@ impl Session { async fn update_token_usage_info( &self, - sub_id: &str, turn_context: &TurnContext, token_usage: Option<&TokenUsage>, ) { @@ -891,37 +909,38 @@ impl Session { ); } } - self.send_token_count_event(sub_id).await; + self.send_token_count_event(turn_context).await; } - async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) { + async fn update_rate_limits( + &self, + turn_context: &TurnContext, + new_rate_limits: RateLimitSnapshot, + ) { { let mut state = self.state.lock().await; state.set_rate_limits(new_rate_limits); } - self.send_token_count_event(sub_id).await; + self.send_token_count_event(turn_context).await; } - async fn send_token_count_event(&self, sub_id: &str) { + async fn send_token_count_event(&self, turn_context: &TurnContext) { let (info, rate_limits) = { let state = self.state.lock().await; state.token_info_and_rate_limits() }; - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::TokenCount(TokenCountEvent { info, rate_limits }), - }; - self.send_event(event).await; + let event = EventMsg::TokenCount(TokenCountEvent { info, rate_limits }); + self.send_event(turn_context, event).await; } - async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) { + async fn set_total_tokens_full(&self, turn_context: &TurnContext) { let context_window = turn_context.client.get_model_context_window(); if let Some(context_window) = context_window { { let mut state = self.state.lock().await; state.set_token_usage_full(context_window); } - self.send_token_count_event(sub_id).await; + self.send_token_count_event(turn_context).await; } } @@ -951,24 +970,22 @@ impl Session { /// Helper that emits a BackgroundEvent with the given message. This keeps /// the call‑sites terse so adding more diagnostics does not clutter the /// core agent logic. - pub(crate) async fn notify_background_event(&self, sub_id: &str, message: impl Into) { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::BackgroundEvent(BackgroundEventEvent { - message: message.into(), - }), - }; - self.send_event(event).await; + pub(crate) async fn notify_background_event( + &self, + turn_context: &TurnContext, + message: impl Into, + ) { + let event = EventMsg::BackgroundEvent(BackgroundEventEvent { + message: message.into(), + }); + self.send_event(turn_context, event).await; } - async fn notify_stream_error(&self, sub_id: &str, message: impl Into) { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::StreamError(StreamErrorEvent { - message: message.into(), - }), - }; - self.send_event(event).await; + async fn notify_stream_error(&self, turn_context: &TurnContext, message: impl Into) { + let event = EventMsg::StreamError(StreamErrorEvent { + message: message.into(), + }); + self.send_event(turn_context, event).await; } /// Build the full turn input by concatenating the current conversation @@ -1129,7 +1146,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv Op::UserInput { items } => (items, SessionSettingsUpdate::default()), _ => unreachable!(), }; - let current_context = sess.new_turn(updates).await; + let current_context = sess.new_turn_with_sub_id(sub.id.clone(), updates).await; current_context .client .get_otel_event_manager() @@ -1145,11 +1162,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv &env_item, sess.show_raw_agent_reasoning(), ) { - let event = Event { - id: sub.id.clone(), - msg, - }; - sess.send_event(event).await; + sess.send_event(¤t_context, msg).await; } } @@ -1158,7 +1171,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv .started_completed(TurnItem::UserMessage(UserMessageItem::new(&items))) .await; - sess.spawn_task(Arc::clone(¤t_context), sub.id, items, RegularTask) + sess.spawn_task(Arc::clone(¤t_context), items, RegularTask) .await; previous_context = Some(current_context); } @@ -1216,7 +1229,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv ), }; - sess_clone.send_event(event).await; + sess_clone.send_event_raw(event).await; }); } Op::ListMcpTools => { @@ -1249,7 +1262,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv }, ), }; - sess.send_event(event).await; + sess.send_event_raw(event).await; } Op::ListCustomPrompts => { let sub_id = sub.id.clone(); @@ -1267,10 +1280,12 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv custom_prompts, }), }; - sess.send_event(event).await; + sess.send_event_raw(event).await; } Op::Compact => { - let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await; + let turn_context = sess + .new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default()) + .await; // Attempt to inject input into current task if let Err(items) = sess .inject_input(vec![UserInput::Text { @@ -1278,7 +1293,7 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv }]) .await { - sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask) + sess.spawn_task(Arc::clone(&turn_context), items, CompactTask) .await; } } @@ -1302,14 +1317,14 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv message: "Failed to shutdown rollout recorder".to_string(), }), }; - sess.send_event(event).await; + sess.send_event_raw(event).await; } let event = Event { id: sub.id.clone(), msg: EventMsg::ShutdownComplete, }; - sess.send_event(event).await; + sess.send_event_raw(event).await; break; } Op::GetPath => { @@ -1337,10 +1352,12 @@ async fn submission_loop(sess: Arc, config: Arc, rx_sub: Receiv path, }), }; - sess.send_event(event).await; + sess.send_event_raw(event).await; } Op::Review { review_request } => { - let turn_context = sess.new_turn(SessionSettingsUpdate::default()).await; + let turn_context = sess + .new_turn_with_sub_id(sub.id.clone(), SessionSettingsUpdate::default()) + .await; spawn_review_thread( sess.clone(), config.clone(), @@ -1416,6 +1433,7 @@ async fn spawn_review_thread( ); let review_turn_context = TurnContext { + sub_id: sub_id.to_string(), client, tools_config, user_instructions: None, @@ -1439,17 +1457,11 @@ async fn spawn_review_thread( text: review_prompt, }]; let tc = Arc::new(review_turn_context); - - // Clone sub_id for the upcoming announcement before moving it into the task. - let sub_id_for_event = sub_id.clone(); - sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await; + sess.spawn_task(tc.clone(), input, ReviewTask).await; // Announce entering review mode so UIs can switch modes. - sess.send_event(Event { - id: sub_id_for_event, - msg: EventMsg::EnteredReviewMode(review_request), - }) - .await; + sess.send_event(&tc, EventMsg::EnteredReviewMode(review_request)) + .await; } /// Takes a user message as input and runs a loop where, at each turn, the model @@ -1472,7 +1484,6 @@ async fn spawn_review_thread( pub(crate) async fn run_task( sess: Arc, turn_context: Arc, - sub_id: String, input: Vec, task_kind: TaskKind, cancellation_token: CancellationToken, @@ -1480,13 +1491,10 @@ pub(crate) async fn run_task( if input.is_empty() { return None; } - let event = Event { - id: sub_id.clone(), - msg: EventMsg::TaskStarted(TaskStartedEvent { - model_context_window: turn_context.client.get_model_context_window(), - }), - }; - sess.send_event(event).await; + let event = EventMsg::TaskStarted(TaskStartedEvent { + model_context_window: turn_context.client.get_model_context_window(), + }); + sess.send_event(&turn_context, event).await; let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); // For review threads, keep an isolated in-memory history so the @@ -1557,7 +1565,6 @@ pub(crate) async fn run_task( Arc::clone(&sess), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), - sub_id.clone(), turn_input, task_kind, cancellation_token.child_token(), @@ -1689,15 +1696,12 @@ pub(crate) async fn run_task( let current_tokens = total_usage_tokens .map(|tokens| tokens.to_string()) .unwrap_or_else(|| "unknown".to_string()); - let event = Event { - id: sub_id.clone(), - msg: EventMsg::Error(ErrorEvent { - message: format!( - "Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input." - ), - }), - }; - sess.send_event(event).await; + let event = EventMsg::Error(ErrorEvent { + message: format!( + "Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input." + ), + }); + sess.send_event(&turn_context, event).await; break; } auto_compact_recently_attempted = true; @@ -1714,7 +1718,7 @@ pub(crate) async fn run_task( sess.notifier() .notify(&UserNotification::AgentTurnComplete { thread_id: sess.conversation_id.to_string(), - turn_id: sub_id.clone(), + turn_id: turn_context.sub_id.clone(), cwd: turn_context.cwd.display().to_string(), input_messages: turn_input_messages, last_assistant_message: last_agent_message.clone(), @@ -1729,13 +1733,10 @@ pub(crate) async fn run_task( } Err(e) => { info!("Turn error: {e:#}"); - let event = Event { - id: sub_id.clone(), - msg: EventMsg::Error(ErrorEvent { - message: e.to_string(), - }), - }; - sess.send_event(event).await; + let event = EventMsg::Error(ErrorEvent { + message: e.to_string(), + }); + sess.send_event(&turn_context, event).await; // let the user continue the conversation break; } @@ -1752,7 +1753,7 @@ pub(crate) async fn run_task( if turn_context.is_review_mode { exit_review_mode( sess.clone(), - sub_id.clone(), + Arc::clone(&turn_context), last_agent_message.as_deref().map(parse_review_output_event), ) .await; @@ -1790,7 +1791,6 @@ async fn run_turn( sess: Arc, turn_context: Arc, turn_diff_tracker: SharedTurnDiffTracker, - sub_id: String, input: Vec, task_kind: TaskKind, cancellation_token: CancellationToken, @@ -1821,7 +1821,6 @@ async fn run_turn( Arc::clone(&sess), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), - &sub_id, &prompt, task_kind, cancellation_token.child_token(), @@ -1834,13 +1833,14 @@ async fn run_turn( Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e @ CodexErr::Fatal(_)) => return Err(e), Err(e @ CodexErr::ContextWindowExceeded) => { - sess.set_total_tokens_full(&sub_id, &turn_context).await; + sess.set_total_tokens_full(turn_context.as_ref()).await; return Err(e); } Err(CodexErr::UsageLimitReached(e)) => { let rate_limits = e.rate_limits.clone(); if let Some(rate_limits) = rate_limits { - sess.update_rate_limits(&sub_id, rate_limits).await; + sess.update_rate_limits(turn_context.as_ref(), rate_limits) + .await; } return Err(CodexErr::UsageLimitReached(e)); } @@ -1862,7 +1862,7 @@ async fn run_turn( // user understands what is happening instead of staring // at a seemingly frozen screen. sess.notify_stream_error( - &sub_id, + turn_context.as_ref(), format!("Re-connecting... {retries}/{max_retries}"), ) .await; @@ -1898,7 +1898,6 @@ async fn try_run_turn( sess: Arc, turn_context: Arc, turn_diff_tracker: SharedTurnDiffTracker, - sub_id: &str, prompt: &Prompt, task_kind: TaskKind, cancellation_token: CancellationToken, @@ -1979,7 +1978,6 @@ async fn try_run_turn( Arc::clone(&sess), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), - sub_id.to_string(), ); let mut output: FuturesOrdered>> = FuturesOrdered::new(); @@ -2028,7 +2026,6 @@ async fn try_run_turn( let response = handle_non_tool_response_item( Arc::clone(&sess), Arc::clone(&turn_context), - sub_id, item.clone(), ) .await?; @@ -2077,7 +2074,7 @@ async fn try_run_turn( let _ = sess .tx_event .send(Event { - id: sub_id.to_string(), + id: turn_context.sub_id.clone(), msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }), }) .await; @@ -2085,13 +2082,14 @@ async fn try_run_turn( ResponseEvent::RateLimits(snapshot) => { // Update internal state with latest rate limits, but defer sending until // token usage is available to avoid duplicate TokenCount events. - sess.update_rate_limits(sub_id, snapshot).await; + sess.update_rate_limits(turn_context.as_ref(), snapshot) + .await; } ResponseEvent::Completed { response_id: _, token_usage, } => { - sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref()) + sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref()) .await; let processed_items = output @@ -2105,11 +2103,7 @@ async fn try_run_turn( }; if let Ok(Some(unified_diff)) = unified_diff { let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); - let event = Event { - id: sub_id.to_string(), - msg, - }; - sess.send_event(event).await; + sess.send_event(&turn_context, msg).await; } let result = TurnRunResult { @@ -2123,38 +2117,27 @@ async fn try_run_turn( // In review child threads, suppress assistant text deltas; the // UI will show a selection popup from the final ReviewOutput. if !turn_context.is_review_mode { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }), - }; - sess.send_event(event).await; + let event = EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }); + sess.send_event(&turn_context, event).await; } else { trace!("suppressing OutputTextDelta in review mode"); } } ResponseEvent::ReasoningSummaryDelta(delta) => { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }), - }; - sess.send_event(event).await; + let event = EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }); + sess.send_event(&turn_context, event).await; } ResponseEvent::ReasoningSummaryPartAdded => { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}), - }; - sess.send_event(event).await; + let event = + EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}); + sess.send_event(&turn_context, event).await; } ResponseEvent::ReasoningContentDelta(delta) => { if sess.show_raw_agent_reasoning() { - let event = Event { - id: sub_id.to_string(), - msg: EventMsg::AgentReasoningRawContentDelta( - AgentReasoningRawContentDeltaEvent { delta }, - ), - }; - sess.send_event(event).await; + let event = EventMsg::AgentReasoningRawContentDelta( + AgentReasoningRawContentDeltaEvent { delta }, + ); + sess.send_event(&turn_context, event).await; } } } @@ -2164,7 +2147,6 @@ async fn try_run_turn( async fn handle_non_tool_response_item( sess: Arc, turn_context: Arc, - sub_id: &str, item: ResponseItem, ) -> CodexResult> { debug!(?item, "Output item"); @@ -2181,11 +2163,7 @@ async fn handle_non_tool_response_item( _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()), }; for msg in msgs { - let event = Event { - id: sub_id.to_string(), - msg, - }; - sess.send_event(event).await; + sess.send_event(&turn_context, msg).await; } } ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => { @@ -2255,16 +2233,13 @@ fn convert_call_tool_result_to_function_call_output_payload( /// and records a developer message with the review output. pub(crate) async fn exit_review_mode( session: Arc, - task_sub_id: String, + turn_context: Arc, review_output: Option, ) { - let event = Event { - id: task_sub_id, - msg: EventMsg::ExitedReviewMode(ExitedReviewModeEvent { - review_output: review_output.clone(), - }), - }; - session.send_event(event).await; + let event = EventMsg::ExitedReviewMode(ExitedReviewModeEvent { + review_output: review_output.clone(), + }); + session.send_event(turn_context.as_ref(), event).await; let mut user_message = String::new(); if let Some(out) = review_output { @@ -2676,6 +2651,7 @@ mod tests { &session_configuration, conversation_id, tx_event.clone(), + "turn_id".to_string(), ); let session = Session { @@ -2744,6 +2720,7 @@ mod tests { &session_configuration, conversation_id, tx_event.clone(), + "turn_id".to_string(), )); let session = Arc::new(Session { @@ -2774,7 +2751,6 @@ mod tests { self: Arc, _session: Arc, _ctx: Arc, - _sub_id: String, _input: Vec, cancellation_token: CancellationToken, ) -> Option { @@ -2787,9 +2763,9 @@ mod tests { } } - async fn abort(&self, session: Arc, sub_id: &str) { + async fn abort(&self, session: Arc, ctx: Arc) { if let TaskKind::Review = self.kind { - exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; + exit_review_mode(session.clone_session(), ctx, None).await; } } } @@ -2798,13 +2774,11 @@ mod tests { #[test_log::test] async fn abort_regular_task_emits_turn_aborted_only() { let (sess, tc, rx) = make_session_and_context_with_rx(); - let sub_id = "sub-regular".to_string(); let input = vec![UserInput::Text { text: "hello".to_string(), }]; sess.spawn_task( Arc::clone(&tc), - sub_id.clone(), input, NeverEndingTask { kind: TaskKind::Regular, @@ -2829,13 +2803,11 @@ mod tests { #[tokio::test] async fn abort_gracefuly_emits_turn_aborted_only() { let (sess, tc, rx) = make_session_and_context_with_rx(); - let sub_id = "sub-regular".to_string(); let input = vec![UserInput::Text { text: "hello".to_string(), }]; sess.spawn_task( Arc::clone(&tc), - sub_id.clone(), input, NeverEndingTask { kind: TaskKind::Regular, @@ -2857,13 +2829,11 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn abort_review_task_emits_exited_then_aborted_and_records_history() { let (sess, tc, rx) = make_session_and_context_with_rx(); - let sub_id = "sub-review".to_string(); let input = vec![UserInput::Text { text: "start review".to_string(), }]; sess.spawn_task( Arc::clone(&tc), - sub_id.clone(), input, NeverEndingTask { kind: TaskKind::Review, @@ -2935,7 +2905,6 @@ mod tests { Arc::clone(&session), Arc::clone(&turn_context), tracker, - "sub-id".to_string(), call, ) .await @@ -3095,7 +3064,6 @@ mod tests { let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let tool_name = "shell"; - let sub_id = "test-sub".to_string(); let call_id = "test-call".to_string(); let resp = handle_container_exec_with_params( @@ -3104,7 +3072,6 @@ mod tests { Arc::clone(&session), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), - sub_id, call_id, ) .await; @@ -3132,7 +3099,6 @@ mod tests { Arc::clone(&session), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), - "test-sub".to_string(), "test-call-2".to_string(), ) .await; diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index a6bb8a8f..bbc1b976 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -10,7 +10,6 @@ use crate::error::Result as CodexResult; use crate::protocol::AgentMessageEvent; use crate::protocol::CompactedItem; use crate::protocol::ErrorEvent; -use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::InputMessageKind; use crate::protocol::TaskStartedEvent; @@ -40,34 +39,28 @@ pub(crate) async fn run_inline_auto_compact_task( sess: Arc, turn_context: Arc, ) { - let sub_id = sess.next_internal_sub_id(); let input = vec![UserInput::Text { text: SUMMARIZATION_PROMPT.to_string(), }]; - run_compact_task_inner(sess, turn_context, sub_id, input).await; + run_compact_task_inner(sess, turn_context, input).await; } pub(crate) async fn run_compact_task( sess: Arc, turn_context: Arc, - sub_id: String, input: Vec, ) -> Option { - let start_event = Event { - id: sub_id.clone(), - msg: EventMsg::TaskStarted(TaskStartedEvent { - model_context_window: turn_context.client.get_model_context_window(), - }), - }; - sess.send_event(start_event).await; - run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await; + let start_event = EventMsg::TaskStarted(TaskStartedEvent { + model_context_window: turn_context.client.get_model_context_window(), + }); + sess.send_event(&turn_context, start_event).await; + run_compact_task_inner(sess.clone(), turn_context, input).await; None } async fn run_compact_task_inner( sess: Arc, turn_context: Arc, - sub_id: String, input: Vec, ) { let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); @@ -94,14 +87,13 @@ async fn run_compact_task_inner( input: turn_input.clone(), ..Default::default() }; - let attempt_result = - drain_to_completed(&sess, turn_context.as_ref(), &sub_id, &prompt).await; + let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await; match attempt_result { Ok(()) => { if truncated_count > 0 { sess.notify_background_event( - &sub_id, + turn_context.as_ref(), format!( "Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window." ), @@ -120,15 +112,11 @@ async fn run_compact_task_inner( retries = 0; continue; } - sess.set_total_tokens_full(&sub_id, turn_context.as_ref()) - .await; - let event = Event { - id: sub_id.clone(), - msg: EventMsg::Error(ErrorEvent { - message: e.to_string(), - }), - }; - sess.send_event(event).await; + sess.set_total_tokens_full(turn_context.as_ref()).await; + let event = EventMsg::Error(ErrorEvent { + message: e.to_string(), + }); + sess.send_event(&turn_context, event).await; return; } Err(e) => { @@ -136,20 +124,17 @@ async fn run_compact_task_inner( retries += 1; let delay = backoff(retries); sess.notify_stream_error( - &sub_id, + turn_context.as_ref(), format!("Re-connecting... {retries}/{max_retries}"), ) .await; tokio::time::sleep(delay).await; continue; } else { - let event = Event { - id: sub_id.clone(), - msg: EventMsg::Error(ErrorEvent { - message: e.to_string(), - }), - }; - sess.send_event(event).await; + let event = EventMsg::Error(ErrorEvent { + message: e.to_string(), + }); + sess.send_event(&turn_context, event).await; return; } } @@ -168,13 +153,10 @@ async fn run_compact_task_inner( }); sess.persist_rollout_items(&[rollout_item]).await; - let event = Event { - id: sub_id.clone(), - msg: EventMsg::AgentMessage(AgentMessageEvent { - message: "Compact task completed".to_string(), - }), - }; - sess.send_event(event).await; + let event = EventMsg::AgentMessage(AgentMessageEvent { + message: "Compact task completed".to_string(), + }); + sess.send_event(&turn_context, event).await; } pub fn content_items_to_text(content: &[ContentItem]) -> Option { @@ -256,7 +238,6 @@ pub(crate) fn build_compacted_history( async fn drain_to_completed( sess: &Session, turn_context: &TurnContext, - sub_id: &str, prompt: &Prompt, ) -> CodexResult<()> { let mut stream = turn_context @@ -277,10 +258,10 @@ async fn drain_to_completed( sess.record_into_history(std::slice::from_ref(&item)).await; } Ok(ResponseEvent::RateLimits(snapshot)) => { - sess.update_rate_limits(sub_id, snapshot).await; + sess.update_rate_limits(turn_context, snapshot).await; } Ok(ResponseEvent::Completed { token_usage, .. }) => { - sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref()) + sess.update_token_usage_info(turn_context, token_usage.as_ref()) .await; return Ok(()); } diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 21fdaf00..09846b71 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -3,7 +3,7 @@ use std::time::Instant; use tracing::error; use crate::codex::Session; -use crate::protocol::Event; +use crate::codex::TurnContext; use crate::protocol::EventMsg; use crate::protocol::McpInvocation; use crate::protocol::McpToolCallBeginEvent; @@ -15,7 +15,7 @@ use codex_protocol::models::ResponseInputItem; /// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`. pub(crate) async fn handle_mcp_tool_call( sess: &Session, - sub_id: &str, + turn_context: &TurnContext, call_id: String, server: String, tool_name: String, @@ -51,7 +51,7 @@ pub(crate) async fn handle_mcp_tool_call( call_id: call_id.clone(), invocation: invocation.clone(), }); - notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await; + notify_mcp_tool_call_event(sess, turn_context, tool_call_begin_event).await; let start = Instant::now(); // Perform the tool call. @@ -69,15 +69,11 @@ pub(crate) async fn handle_mcp_tool_call( result: result.clone(), }); - notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await; + notify_mcp_tool_call_event(sess, turn_context, tool_call_end_event.clone()).await; ResponseInputItem::McpToolCallOutput { call_id, result } } -async fn notify_mcp_tool_call_event(sess: &Session, sub_id: &str, event: EventMsg) { - sess.send_event(Event { - id: sub_id.to_string(), - msg: event, - }) - .await; +async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) { + sess.send_event(turn_context, event).await; } diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index 3ed63c5b..e2ed6f30 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -11,6 +11,7 @@ use tokio_util::task::AbortOnDropHandle; use codex_protocol::models::ResponseInputItem; use tokio::sync::oneshot; +use crate::codex::TurnContext; use crate::protocol::ReviewDecision; use crate::tasks::SessionTask; @@ -53,10 +54,12 @@ pub(crate) struct RunningTask { pub(crate) task: Arc, pub(crate) cancellation_token: CancellationToken, pub(crate) handle: Arc>, + pub(crate) turn_context: Arc, } impl ActiveTurn { - pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) { + pub(crate) fn add_task(&mut self, task: RunningTask) { + let sub_id = task.turn_context.sub_id.clone(); self.tasks.insert(sub_id, task); } @@ -65,8 +68,8 @@ impl ActiveTurn { self.tasks.is_empty() } - pub(crate) fn drain_tasks(&mut self) -> IndexMap { - std::mem::take(&mut self.tasks) + pub(crate) fn drain_tasks(&mut self) -> Vec { + self.tasks.drain(..).map(|(_, task)| task).collect() } } diff --git a/codex-rs/core/src/tasks/compact.rs b/codex-rs/core/src/tasks/compact.rs index a27e68dd..64b2a9d2 100644 --- a/codex-rs/core/src/tasks/compact.rs +++ b/codex-rs/core/src/tasks/compact.rs @@ -24,10 +24,9 @@ impl SessionTask for CompactTask { self: Arc, session: Arc, ctx: Arc, - sub_id: String, input: Vec, _cancellation_token: CancellationToken, ) -> Option { - compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await + compact::run_compact_task(session.clone_session(), ctx, input).await } } diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 6a5dff6a..79527814 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -15,7 +15,6 @@ use tracing::warn; use crate::codex::Session; use crate::codex::TurnContext; -use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::TaskCompleteEvent; use crate::protocol::TurnAbortReason; @@ -55,13 +54,12 @@ pub(crate) trait SessionTask: Send + Sync + 'static { self: Arc, session: Arc, ctx: Arc, - sub_id: String, input: Vec, cancellation_token: CancellationToken, ) -> Option; - async fn abort(&self, session: Arc, sub_id: &str) { - let _ = (session, sub_id); + async fn abort(&self, session: Arc, ctx: Arc) { + let _ = (session, ctx); } } @@ -69,7 +67,6 @@ impl Session { pub async fn spawn_task( self: &Arc, turn_context: Arc, - sub_id: String, input: Vec, task: T, ) { @@ -86,14 +83,13 @@ impl Session { let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); let ctx = Arc::clone(&turn_context); let task_for_run = Arc::clone(&task); - let sub_clone = sub_id.clone(); let task_cancellation_token = cancellation_token.child_token(); tokio::spawn(async move { + let ctx_for_finish = Arc::clone(&ctx); let last_agent_message = task_for_run .run( Arc::clone(&session_ctx), ctx, - sub_clone.clone(), input, task_cancellation_token.child_token(), ) @@ -102,7 +98,8 @@ impl Session { if !task_cancellation_token.is_cancelled() { // Emit completion uniformly from spawn site so all tasks share the same lifecycle. let sess = session_ctx.clone_session(); - sess.on_task_finished(sub_clone, last_agent_message).await; + sess.on_task_finished(ctx_for_finish, last_agent_message) + .await; } done_clone.notify_waiters(); }) @@ -114,60 +111,54 @@ impl Session { kind: task_kind, task, cancellation_token, + turn_context: Arc::clone(&turn_context), }; - self.register_new_active_task(sub_id, running_task).await; + self.register_new_active_task(running_task).await; } pub async fn abort_all_tasks(self: &Arc, reason: TurnAbortReason) { - for (sub_id, task) in self.take_all_running_tasks().await { - self.handle_task_abort(sub_id, task, reason.clone()).await; + for task in self.take_all_running_tasks().await { + self.handle_task_abort(task, reason.clone()).await; } } pub async fn on_task_finished( self: &Arc, - sub_id: String, + turn_context: Arc, last_agent_message: Option, ) { let mut active = self.active_turn.lock().await; if let Some(at) = active.as_mut() - && at.remove_task(&sub_id) + && at.remove_task(&turn_context.sub_id) { *active = None; } drop(active); - let event = Event { - id: sub_id, - msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }), - }; - self.send_event(event).await; + let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }); + self.send_event(turn_context.as_ref(), event).await; } - async fn register_new_active_task(&self, sub_id: String, task: RunningTask) { + async fn register_new_active_task(&self, task: RunningTask) { let mut active = self.active_turn.lock().await; let mut turn = ActiveTurn::default(); - turn.add_task(sub_id, task); + turn.add_task(task); *active = Some(turn); } - async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> { + async fn take_all_running_tasks(&self) -> Vec { let mut active = self.active_turn.lock().await; match active.take() { Some(mut at) => { at.clear_pending().await; - let tasks = at.drain_tasks(); - tasks.into_iter().collect() + + at.drain_tasks() } None => Vec::new(), } } - async fn handle_task_abort( - self: &Arc, - sub_id: String, - task: RunningTask, - reason: TurnAbortReason, - ) { + async fn handle_task_abort(self: &Arc, task: RunningTask, reason: TurnAbortReason) { + let sub_id = task.turn_context.sub_id.clone(); if task.cancellation_token.is_cancelled() { return; } @@ -187,13 +178,12 @@ impl Session { task.handle.abort(); let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); - session_task.abort(session_ctx, &sub_id).await; + session_task + .abort(session_ctx, Arc::clone(&task.turn_context)) + .await; - let event = Event { - id: sub_id.clone(), - msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }), - }; - self.send_event(event).await; + let event = EventMsg::TurnAborted(TurnAbortedEvent { reason }); + self.send_event(task.turn_context.as_ref(), event).await; } } diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index a79d842c..58ecedd4 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -24,19 +24,10 @@ impl SessionTask for RegularTask { self: Arc, session: Arc, ctx: Arc, - sub_id: String, input: Vec, cancellation_token: CancellationToken, ) -> Option { let sess = session.clone_session(); - run_task( - sess, - ctx, - sub_id, - input, - TaskKind::Regular, - cancellation_token, - ) - .await + run_task(sess, ctx, input, TaskKind::Regular, cancellation_token).await } } diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs index e4dfeec6..fbf553b3 100644 --- a/codex-rs/core/src/tasks/review.rs +++ b/codex-rs/core/src/tasks/review.rs @@ -25,23 +25,14 @@ impl SessionTask for ReviewTask { self: Arc, session: Arc, ctx: Arc, - sub_id: String, input: Vec, cancellation_token: CancellationToken, ) -> Option { let sess = session.clone_session(); - run_task( - sess, - ctx, - sub_id, - input, - TaskKind::Review, - cancellation_token, - ) - .await + run_task(sess, ctx, input, TaskKind::Review, cancellation_token).await } - async fn abort(&self, session: Arc, sub_id: &str) { - exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; + async fn abort(&self, session: Arc, ctx: Arc) { + exit_review_mode(session.clone_session(), ctx, None).await; } } diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index a262404b..27d309dc 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -24,7 +24,6 @@ pub struct ToolInvocation { pub session: Arc, pub turn: Arc, pub tracker: SharedTurnDiffTracker, - pub sub_id: String, pub call_id: String, pub tool_name: String, pub payload: ToolPayload, @@ -234,7 +233,7 @@ mod tests { #[derive(Clone, Debug)] #[allow(dead_code)] pub(crate) struct ExecCommandContext { - pub(crate) sub_id: String, + pub(crate) turn: Arc, pub(crate) call_id: String, pub(crate) command_for_display: Vec, pub(crate) cwd: PathBuf, diff --git a/codex-rs/core/src/tools/events.rs b/codex-rs/core/src/tools/events.rs index e09f266b..af8afe3e 100644 --- a/codex-rs/core/src/tools/events.rs +++ b/codex-rs/core/src/tools/events.rs @@ -1,7 +1,7 @@ use crate::codex::Session; +use crate::codex::TurnContext; use crate::exec::ExecToolCallOutput; use crate::parse_command::parse_command; -use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::ExecCommandBeginEvent; use crate::protocol::ExecCommandEndEvent; @@ -20,7 +20,7 @@ use super::format_exec_output_str; #[derive(Clone, Copy)] pub(crate) struct ToolEventCtx<'a> { pub session: &'a Session, - pub sub_id: &'a str, + pub turn: &'a TurnContext, pub call_id: &'a str, pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>, } @@ -28,13 +28,13 @@ pub(crate) struct ToolEventCtx<'a> { impl<'a> ToolEventCtx<'a> { pub fn new( session: &'a Session, - sub_id: &'a str, + turn: &'a TurnContext, call_id: &'a str, turn_diff_tracker: Option<&'a SharedTurnDiffTracker>, ) -> Self { Self { session, - sub_id, + turn, call_id, turn_diff_tracker, } @@ -79,15 +79,15 @@ impl ToolEmitter { match (self, stage) { (Self::Shell { command, cwd }, ToolEventStage::Begin) => { ctx.session - .send_event(Event { - id: ctx.sub_id.to_string(), - msg: EventMsg::ExecCommandBegin(ExecCommandBeginEvent { + .send_event( + ctx.turn, + EventMsg::ExecCommandBegin(ExecCommandBeginEvent { call_id: ctx.call_id.to_string(), command: command.clone(), cwd: cwd.clone(), parsed_cmd: parse_command(command), }), - }) + ) .await; } (Self::Shell { .. }, ToolEventStage::Success(output)) => { @@ -139,14 +139,14 @@ impl ToolEmitter { guard.on_patch_begin(changes); } ctx.session - .send_event(Event { - id: ctx.sub_id.to_string(), - msg: EventMsg::PatchApplyBegin(PatchApplyBeginEvent { + .send_event( + ctx.turn, + EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id: ctx.call_id.to_string(), auto_approved: *auto_approved, changes: changes.clone(), }), - }) + ) .await; } (Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => { @@ -190,9 +190,9 @@ async fn emit_exec_end( formatted_output: String, ) { ctx.session - .send_event(Event { - id: ctx.sub_id.to_string(), - msg: EventMsg::ExecCommandEnd(ExecCommandEndEvent { + .send_event( + ctx.turn, + EventMsg::ExecCommandEnd(ExecCommandEndEvent { call_id: ctx.call_id.to_string(), stdout, stderr, @@ -201,21 +201,21 @@ async fn emit_exec_end( duration, formatted_output, }), - }) + ) .await; } async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) { ctx.session - .send_event(Event { - id: ctx.sub_id.to_string(), - msg: EventMsg::PatchApplyEnd(PatchApplyEndEvent { + .send_event( + ctx.turn, + EventMsg::PatchApplyEnd(PatchApplyEndEvent { call_id: ctx.call_id.to_string(), stdout, stderr, success, }), - }) + ) .await; if let Some(tracker) = ctx.turn_diff_tracker { @@ -225,10 +225,7 @@ async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, s }; if let Ok(Some(unified_diff)) = unified_diff { ctx.session - .send_event(Event { - id: ctx.sub_id.to_string(), - msg: EventMsg::TurnDiff(TurnDiffEvent { unified_diff }), - }) + .send_event(ctx.turn, EventMsg::TurnDiff(TurnDiffEvent { unified_diff })) .await; } } diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 4223740c..d91db362 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -42,7 +42,6 @@ impl ToolHandler for ApplyPatchHandler { session, turn, tracker, - sub_id, call_id, tool_name, payload, @@ -81,7 +80,6 @@ impl ToolHandler for ApplyPatchHandler { Arc::clone(&session), Arc::clone(&turn), Arc::clone(&tracker), - sub_id.clone(), call_id.clone(), ) .await?; diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs index ba95a5ea..4b2bf3b8 100644 --- a/codex-rs/core/src/tools/handlers/mcp.rs +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -19,7 +19,7 @@ impl ToolHandler for McpHandler { async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, - sub_id, + turn, call_id, payload, .. @@ -43,7 +43,7 @@ impl ToolHandler for McpHandler { let response = handle_mcp_tool_call( session.as_ref(), - &sub_id, + turn.as_ref(), call_id.clone(), server, tool, diff --git a/codex-rs/core/src/tools/handlers/mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource.rs index 7e425aad..be496f01 100644 --- a/codex-rs/core/src/tools/handlers/mcp_resource.rs +++ b/codex-rs/core/src/tools/handlers/mcp_resource.rs @@ -21,8 +21,8 @@ use serde::de::DeserializeOwned; use serde_json::Value; use crate::codex::Session; +use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; -use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::McpInvocation; use crate::protocol::McpToolCallBeginEvent; @@ -189,7 +189,7 @@ impl ToolHandler for McpResourceHandler { async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, - sub_id, + turn, call_id, tool_name, payload, @@ -211,7 +211,7 @@ impl ToolHandler for McpResourceHandler { "list_mcp_resources" => { handle_list_resources( Arc::clone(&session), - sub_id.clone(), + Arc::clone(&turn), call_id.clone(), arguments_value.clone(), ) @@ -220,14 +220,20 @@ impl ToolHandler for McpResourceHandler { "list_mcp_resource_templates" => { handle_list_resource_templates( Arc::clone(&session), - sub_id.clone(), + Arc::clone(&turn), call_id.clone(), arguments_value.clone(), ) .await } "read_mcp_resource" => { - handle_read_resource(Arc::clone(&session), sub_id, call_id, arguments_value).await + handle_read_resource( + Arc::clone(&session), + Arc::clone(&turn), + call_id, + arguments_value, + ) + .await } other => Err(FunctionCallError::RespondToModel(format!( "unsupported MCP resource tool: {other}" @@ -238,7 +244,7 @@ impl ToolHandler for McpResourceHandler { async fn handle_list_resources( session: Arc, - sub_id: String, + turn: Arc, call_id: String, arguments: Option, ) -> Result { @@ -253,7 +259,7 @@ async fn handle_list_resources( arguments: arguments.clone(), }; - emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await; + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; let start = Instant::now(); let payload_result: Result = async { @@ -297,7 +303,7 @@ async fn handle_list_resources( let duration = start.elapsed(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -311,7 +317,7 @@ async fn handle_list_resources( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -326,7 +332,7 @@ async fn handle_list_resources( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -340,7 +346,7 @@ async fn handle_list_resources( async fn handle_list_resource_templates( session: Arc, - sub_id: String, + turn: Arc, call_id: String, arguments: Option, ) -> Result { @@ -355,7 +361,7 @@ async fn handle_list_resource_templates( arguments: arguments.clone(), }; - emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await; + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; let start = Instant::now(); let payload_result: Result = async { @@ -403,7 +409,7 @@ async fn handle_list_resource_templates( let duration = start.elapsed(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -417,7 +423,7 @@ async fn handle_list_resource_templates( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -432,7 +438,7 @@ async fn handle_list_resource_templates( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -446,7 +452,7 @@ async fn handle_list_resource_templates( async fn handle_read_resource( session: Arc, - sub_id: String, + turn: Arc, call_id: String, arguments: Option, ) -> Result { @@ -461,7 +467,7 @@ async fn handle_read_resource( arguments: arguments.clone(), }; - emit_tool_call_begin(&session, &sub_id, &call_id, invocation.clone()).await; + emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await; let start = Instant::now(); let payload_result: Result = async { @@ -489,7 +495,7 @@ async fn handle_read_resource( let duration = start.elapsed(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -503,7 +509,7 @@ async fn handle_read_resource( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -518,7 +524,7 @@ async fn handle_read_resource( let message = err.to_string(); emit_tool_call_end( &session, - &sub_id, + turn.as_ref(), &call_id, invocation, duration, @@ -544,39 +550,39 @@ fn call_tool_result_from_content(content: &str, success: Option) -> CallTo async fn emit_tool_call_begin( session: &Arc, - sub_id: &str, + turn: &TurnContext, call_id: &str, invocation: McpInvocation, ) { session - .send_event(Event { - id: sub_id.to_string(), - msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent { + .send_event( + turn, + EventMsg::McpToolCallBegin(McpToolCallBeginEvent { call_id: call_id.to_string(), invocation, }), - }) + ) .await; } async fn emit_tool_call_end( session: &Arc, - sub_id: &str, + turn: &TurnContext, call_id: &str, invocation: McpInvocation, duration: Duration, result: Result, ) { session - .send_event(Event { - id: sub_id.to_string(), - msg: EventMsg::McpToolCallEnd(McpToolCallEndEvent { + .send_event( + turn, + EventMsg::McpToolCallEnd(McpToolCallEndEvent { call_id: call_id.to_string(), invocation, duration, result, }), - }) + ) .await; } diff --git a/codex-rs/core/src/tools/handlers/plan.rs b/codex-rs/core/src/tools/handlers/plan.rs index 9291443f..ba8de6be 100644 --- a/codex-rs/core/src/tools/handlers/plan.rs +++ b/codex-rs/core/src/tools/handlers/plan.rs @@ -1,6 +1,7 @@ use crate::client_common::tools::ResponsesApiTool; use crate::client_common::tools::ToolSpec; use crate::codex::Session; +use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; @@ -10,7 +11,6 @@ use crate::tools::registry::ToolKind; use crate::tools::spec::JsonSchema; use async_trait::async_trait; use codex_protocol::plan_tool::UpdatePlanArgs; -use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; use std::collections::BTreeMap; use std::sync::LazyLock; @@ -68,7 +68,7 @@ impl ToolHandler for PlanHandler { async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, - sub_id, + turn, call_id, payload, .. @@ -84,7 +84,7 @@ impl ToolHandler for PlanHandler { }; let content = - handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?; + handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?; Ok(ToolOutput::Function { content, @@ -98,16 +98,13 @@ impl ToolHandler for PlanHandler { /// than forcing it to come up and document a plan (TBD how that affects performance). pub(crate) async fn handle_update_plan( session: &Session, + turn_context: &TurnContext, arguments: String, - sub_id: String, _call_id: String, ) -> Result { let args = parse_update_plan_arguments(&arguments)?; session - .send_event(Event { - id: sub_id.to_string(), - msg: EventMsg::PlanUpdate(args), - }) + .send_event(turn_context, EventMsg::PlanUpdate(args)) .await; Ok("Plan updated".to_string()) } diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 168b8cd0..e19a40be 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -47,7 +47,6 @@ impl ToolHandler for ShellHandler { session, turn, tracker, - sub_id, call_id, tool_name, payload, @@ -68,7 +67,6 @@ impl ToolHandler for ShellHandler { Arc::clone(&session), Arc::clone(&turn), Arc::clone(&tracker), - sub_id.clone(), call_id.clone(), ) .await?; @@ -85,7 +83,6 @@ impl ToolHandler for ShellHandler { Arc::clone(&session), Arc::clone(&turn), Arc::clone(&tracker), - sub_id.clone(), call_id.clone(), ) .await?; diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index d171a1f7..2c238909 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -37,7 +37,6 @@ impl ToolHandler for UnifiedExecHandler { let ToolInvocation { session, turn, - sub_id, call_id, tool_name: _tool_name, payload, @@ -91,7 +90,6 @@ impl ToolHandler for UnifiedExecHandler { crate::unified_exec::UnifiedExecContext { session: &session, turn: turn.as_ref(), - sub_id: &sub_id, call_id: &call_id, session_id: parsed_session_id, }, diff --git a/codex-rs/core/src/tools/handlers/view_image.rs b/codex-rs/core/src/tools/handlers/view_image.rs index 43cada47..b25642d8 100644 --- a/codex-rs/core/src/tools/handlers/view_image.rs +++ b/codex-rs/core/src/tools/handlers/view_image.rs @@ -3,7 +3,6 @@ use serde::Deserialize; use tokio::fs; use crate::function_tool::FunctionCallError; -use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::ViewImageToolCallEvent; use crate::tools::context::ToolInvocation; @@ -31,7 +30,6 @@ impl ToolHandler for ViewImageHandler { session, turn, payload, - sub_id, call_id, .. } = invocation; @@ -76,13 +74,13 @@ impl ToolHandler for ViewImageHandler { })?; session - .send_event(Event { - id: sub_id.to_string(), - msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent { + .send_event( + turn.as_ref(), + EventMsg::ViewImageToolCall(ViewImageToolCallEvent { call_id, path: event_path, }), - }) + ) .await; Ok(ToolOutput::Function { diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs index c0e9a90f..d11caf8e 100644 --- a/codex-rs/core/src/tools/mod.rs +++ b/codex-rs/core/src/tools/mod.rs @@ -61,7 +61,6 @@ pub(crate) async fn handle_container_exec_with_params( sess: Arc, turn_context: Arc, turn_diff_tracker: SharedTurnDiffTracker, - sub_id: String, call_id: String, ) -> Result { let _otel_event_manager = turn_context.client.get_otel_event_manager(); @@ -78,14 +77,8 @@ pub(crate) async fn handle_container_exec_with_params( // check if this was a patch, and apply it if so let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { MaybeApplyPatchVerified::Body(changes) => { - match apply_patch::apply_patch( - sess.as_ref(), - turn_context.as_ref(), - &sub_id, - &call_id, - changes, - ) - .await + match apply_patch::apply_patch(sess.as_ref(), turn_context.as_ref(), &call_id, changes) + .await { InternalApplyPatchInvocation::Output(item) => return item, InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { @@ -122,7 +115,7 @@ pub(crate) async fn handle_container_exec_with_params( ), }; - let event_ctx = ToolEventCtx::new(sess.as_ref(), &sub_id, &call_id, diff_opt); + let event_ctx = ToolEventCtx::new(sess.as_ref(), turn_context.as_ref(), &call_id, diff_opt); event_emitter.emit(event_ctx, ToolEventStage::Begin).await; // Build runtime contexts only when needed (shell/apply_patch below). @@ -141,7 +134,7 @@ pub(crate) async fn handle_container_exec_with_params( let mut runtime = ApplyPatchRuntime::new(); let tool_ctx = ToolCtx { session: sess.as_ref(), - sub_id: sub_id.clone(), + turn: turn_context.as_ref(), call_id: call_id.clone(), tool_name: tool_name.to_string(), }; @@ -172,7 +165,7 @@ pub(crate) async fn handle_container_exec_with_params( let mut runtime = ShellRuntime::new(); let tool_ctx = ToolCtx { session: sess.as_ref(), - sub_id: sub_id.clone(), + turn: turn_context.as_ref(), call_id: call_id.clone(), tool_name: tool_name.to_string(), }; diff --git a/codex-rs/core/src/tools/orchestrator.rs b/codex-rs/core/src/tools/orchestrator.rs index 4d061633..bdc4e3af 100644 --- a/codex-rs/core/src/tools/orchestrator.rs +++ b/codex-rs/core/src/tools/orchestrator.rs @@ -53,7 +53,7 @@ impl ToolOrchestrator { if needs_initial_approval { let approval_ctx = ApprovalCtx { session: tool_ctx.session, - sub_id: &tool_ctx.sub_id, + turn: turn_ctx, call_id: &tool_ctx.call_id, retry_reason: None, }; @@ -110,7 +110,7 @@ impl ToolOrchestrator { let reason_msg = build_denial_reason_from_output(output.as_ref()); let approval_ctx = ApprovalCtx { session: tool_ctx.session, - sub_id: &tool_ctx.sub_id, + turn: turn_ctx, call_id: &tool_ctx.call_id, retry_reason: Some(reason_msg), }; diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 26dfed8e..eae181c1 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -18,7 +18,6 @@ pub(crate) struct ToolCallRuntime { session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, - sub_id: String, parallel_execution: Arc>, } @@ -28,14 +27,12 @@ impl ToolCallRuntime { session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, - sub_id: String, ) -> Self { Self { router, session, turn_context, tracker, - sub_id, parallel_execution: Arc::new(RwLock::new(())), } } @@ -50,7 +47,6 @@ impl ToolCallRuntime { let session = Arc::clone(&self.session); let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); - let sub_id = self.sub_id.clone(); let lock = Arc::clone(&self.parallel_execution); let handle: AbortOnDropHandle> = @@ -62,7 +58,7 @@ impl ToolCallRuntime { }; router - .dispatch_tool_call(session, turn, tracker, sub_id, call) + .dispatch_tool_call(session, turn, tracker, call) .await })); diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index fa6e38a4..161997fb 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -134,7 +134,6 @@ impl ToolRouter { session: Arc, turn: Arc, tracker: SharedTurnDiffTracker, - sub_id: String, call: ToolCall, ) -> Result { let ToolCall { @@ -149,7 +148,6 @@ impl ToolRouter { session, turn, tracker, - sub_id, call_id, tool_name, payload, diff --git a/codex-rs/core/src/tools/runtimes/apply_patch.rs b/codex-rs/core/src/tools/runtimes/apply_patch.rs index f6062a3a..eb1cda4e 100644 --- a/codex-rs/core/src/tools/runtimes/apply_patch.rs +++ b/codex-rs/core/src/tools/runtimes/apply_patch.rs @@ -68,7 +68,7 @@ impl ApplyPatchRuntime { fn stdout_stream(ctx: &ToolCtx<'_>) -> Option { Some(crate::exec::StdoutStream { - sub_id: ctx.sub_id.clone(), + sub_id: ctx.turn.sub_id.clone(), call_id: ctx.call_id.clone(), tx_event: ctx.session.get_tx_event(), }) @@ -101,7 +101,7 @@ impl Approvable for ApplyPatchRuntime { ) -> BoxFuture<'a, ReviewDecision> { let key = self.approval_key(req); let session = ctx.session; - let sub_id = ctx.sub_id.to_string(); + let turn = ctx.turn; let call_id = ctx.call_id.to_string(); let cwd = req.cwd.clone(); let retry_reason = ctx.retry_reason.clone(); @@ -111,7 +111,7 @@ impl Approvable for ApplyPatchRuntime { if let Some(reason) = retry_reason { session .request_command_approval( - sub_id, + turn, call_id, vec!["apply_patch".to_string()], cwd, diff --git a/codex-rs/core/src/tools/runtimes/shell.rs b/codex-rs/core/src/tools/runtimes/shell.rs index 313e9e07..bfc2114f 100644 --- a/codex-rs/core/src/tools/runtimes/shell.rs +++ b/codex-rs/core/src/tools/runtimes/shell.rs @@ -51,7 +51,7 @@ impl ShellRuntime { fn stdout_stream(ctx: &ToolCtx<'_>) -> Option { Some(crate::exec::StdoutStream { - sub_id: ctx.sub_id.clone(), + sub_id: ctx.turn.sub_id.clone(), call_id: ctx.call_id.clone(), tx_event: ctx.session.get_tx_event(), }) @@ -91,12 +91,12 @@ impl Approvable for ShellRuntime { .clone() .or_else(|| req.justification.clone()); let session = ctx.session; - let sub_id = ctx.sub_id.to_string(); + let turn = ctx.turn; let call_id = ctx.call_id.to_string(); Box::pin(async move { with_cached_approval(&session.services, key, || async move { session - .request_command_approval(sub_id, call_id, command, cwd, reason) + .request_command_approval(turn, call_id, command, cwd, reason) .await }) .await diff --git a/codex-rs/core/src/tools/runtimes/unified_exec.rs b/codex-rs/core/src/tools/runtimes/unified_exec.rs index e6421055..c7d136eb 100644 --- a/codex-rs/core/src/tools/runtimes/unified_exec.rs +++ b/codex-rs/core/src/tools/runtimes/unified_exec.rs @@ -80,7 +80,7 @@ impl Approvable for UnifiedExecRuntime<'_> { ) -> BoxFuture<'b, ReviewDecision> { let key = self.approval_key(req); let session = ctx.session; - let sub_id = ctx.sub_id.to_string(); + let turn = ctx.turn; let call_id = ctx.call_id.to_string(); let command = req.command.clone(); let cwd = req.cwd.clone(); @@ -88,7 +88,7 @@ impl Approvable for UnifiedExecRuntime<'_> { Box::pin(async move { with_cached_approval(&session.services, key, || async move { session - .request_command_approval(sub_id, call_id, command, cwd, reason) + .request_command_approval(turn, call_id, command, cwd, reason) .await }) .await diff --git a/codex-rs/core/src/tools/sandboxing.rs b/codex-rs/core/src/tools/sandboxing.rs index ed142da4..7c4d65ca 100644 --- a/codex-rs/core/src/tools/sandboxing.rs +++ b/codex-rs/core/src/tools/sandboxing.rs @@ -5,6 +5,7 @@ //! and helpers (`Sandboxable`, `ToolRuntime`, `SandboxAttempt`, etc.). use crate::codex::Session; +use crate::codex::TurnContext; use crate::error::CodexErr; use crate::protocol::SandboxPolicy; use crate::sandboxing::CommandSpec; @@ -77,7 +78,7 @@ where #[derive(Clone)] pub(crate) struct ApprovalCtx<'a> { pub session: &'a Session, - pub sub_id: &'a str, + pub turn: &'a TurnContext, pub call_id: &'a str, pub retry_reason: Option, } @@ -145,7 +146,7 @@ pub(crate) trait Sandboxable { pub(crate) struct ToolCtx<'a> { pub session: &'a Session, - pub sub_id: String, + pub turn: &'a TurnContext, pub call_id: String, pub tool_name: String, } diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index fb791fe8..7118e20e 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -43,7 +43,6 @@ const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB pub(crate) struct UnifiedExecContext<'a> { pub session: &'a Session, pub turn: &'a TurnContext, - pub sub_id: &'a str, pub call_id: &'a str, pub session_id: Option, } @@ -110,7 +109,6 @@ mod tests { UnifiedExecContext { session, turn: turn.as_ref(), - sub_id: "sub", call_id: "call", session_id, }, diff --git a/codex-rs/core/src/unified_exec/session_manager.rs b/codex-rs/core/src/unified_exec/session_manager.rs index 8bc8cb29..83b076d9 100644 --- a/codex-rs/core/src/unified_exec/session_manager.rs +++ b/codex-rs/core/src/unified_exec/session_manager.rs @@ -113,7 +113,7 @@ impl UnifiedExecSessionManager { ); let tool_ctx = ToolCtx { session: context.session, - sub_id: context.sub_id.to_string(), + turn: context.turn, call_id: context.call_id.to_string(), tool_name: "unified_exec".to_string(), };