diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 84a04d75..48aa3dbd 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; use std::collections::HashMap; -use std::collections::HashSet; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; @@ -111,7 +110,6 @@ use crate::protocol::Submission; use crate::protocol::TaskCompleteEvent; use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; -use crate::protocol::TokenUsageInfo; use crate::protocol::TurnDiffEvent; use crate::protocol::WebSearchBeginEvent; use crate::rollout::RolloutRecorder; @@ -120,6 +118,8 @@ use crate::safety::SafetyCheck; use crate::safety::assess_command_safety; use crate::safety::assess_safety_for_untrusted_command; use crate::shell; +use crate::state::ActiveTurn; +use crate::state::SessionServices; use crate::turn_diff_tracker::TurnDiffTracker; use crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; @@ -253,17 +253,7 @@ impl Codex { } } -/// Mutable state of the agent -#[derive(Default)] -struct State { - approved_commands: HashSet>, - current_task: Option, - pending_approvals: HashMap>, - pending_input: Vec, - history: ConversationHistory, - token_info: Option, - latest_rate_limits: Option, -} +use crate::state::SessionState; /// Context for an initialized model agent /// @@ -271,21 +261,9 @@ struct State { pub(crate) struct Session { conversation_id: ConversationId, tx_event: Sender, - - /// Manager for external MCP servers/tools. - mcp_connection_manager: McpConnectionManager, - session_manager: ExecSessionManager, - unified_exec_manager: UnifiedExecSessionManager, - - notifier: UserNotifier, - - /// Optional rollout recorder for persisting the conversation transcript so - /// sessions can be replayed or inspected later. - rollout: Mutex>, - state: Mutex, - codex_linux_sandbox_exe: Option, - user_shell: shell::Shell, - show_raw_agent_reasoning: bool, + state: Mutex, + active_turn: Mutex>, + services: SessionServices, next_internal_sub_id: AtomicU64, } @@ -413,10 +391,7 @@ impl Session { })?; let rollout_path = rollout_recorder.rollout_path.clone(); // Create the mutable state for the Session. - let state = State { - history: ConversationHistory::new(), - ..Default::default() - }; + let state = SessionState::new(); // Handle MCP manager result and record any startup failures. let (mcp_connection_manager, failed_clients) = match mcp_res { @@ -474,18 +449,23 @@ impl Session { is_review_mode: false, final_output_json_schema: None, }; - let sess = Arc::new(Session { - conversation_id, - tx_event: tx_event.clone(), + let services = SessionServices { mcp_connection_manager, session_manager: ExecSessionManager::default(), unified_exec_manager: UnifiedExecSessionManager::default(), notifier: notify, - state: Mutex::new(state), rollout: Mutex::new(Some(rollout_recorder)), codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), user_shell: default_shell, show_raw_agent_reasoning: config.show_raw_agent_reasoning, + }; + + let sess = Arc::new(Session { + conversation_id, + tx_event: tx_event.clone(), + state: Mutex::new(state), + active_turn: Mutex::new(None), + services, next_internal_sub_id: AtomicU64::new(0), }); @@ -521,6 +501,15 @@ impl Session { current_task.abort(TurnAbortReason::Replaced); } state.current_task = Some(task); + if let Some(current_task) = &state.current_task { + let mut active = self.active_turn.lock().await; + *active = Some(ActiveTurn { + sub_id: current_task.sub_id.clone(), + turn_state: std::sync::Arc::new(tokio::sync::Mutex::new( + crate::state::TurnState::default(), + )), + }); + } } pub async fn remove_task(&self, sub_id: &str) { @@ -530,6 +519,12 @@ impl Session { { state.current_task.take(); } + let mut active = self.active_turn.lock().await; + if let Some(at) = &*active + && at.sub_id == sub_id + { + *active = None; + } } fn next_internal_sub_id(&self) -> String { @@ -591,8 +586,14 @@ impl Session { let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); let prev_entry = { - let mut state = self.state.lock().await; - state.pending_approvals.insert(sub_id, tx_approve) + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.insert_pending_approval(sub_id, tx_approve) + } + None => None, + } }; if prev_entry.is_some() { warn!("Overwriting existing pending approval for sub_id: {event_id}"); @@ -623,8 +624,14 @@ impl Session { let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); let prev_entry = { - let mut state = self.state.lock().await; - state.pending_approvals.insert(sub_id, tx_approve) + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.insert_pending_approval(sub_id, tx_approve) + } + None => None, + } }; if prev_entry.is_some() { warn!("Overwriting existing pending approval for sub_id: {event_id}"); @@ -645,8 +652,14 @@ impl Session { pub async fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) { let entry = { - let mut state = self.state.lock().await; - state.pending_approvals.remove(sub_id) + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.remove_pending_approval(sub_id) + } + None => None, + } }; match entry { Some(tx_approve) => { @@ -660,7 +673,7 @@ impl Session { pub async fn add_approved_command(&self, cmd: Vec) { let mut state = self.state.lock().await; - state.approved_commands.insert(cmd); + state.add_approved_command(cmd); } /// Records input items: always append to conversation history and @@ -700,7 +713,12 @@ impl Session { /// Append ResponseItems to the in-memory conversation history only. async fn record_into_history(&self, items: &[ResponseItem]) { let mut state = self.state.lock().await; - state.history.record_items(items.iter()); + state.record_items(items.iter()); + } + + async fn replace_history(&self, items: Vec) { + let mut state = self.state.lock().await; + state.replace_history(items); } async fn persist_rollout_response_items(&self, items: &[ResponseItem]) { @@ -721,14 +739,14 @@ impl Session { Some(turn_context.cwd.clone()), Some(turn_context.approval_policy), Some(turn_context.sandbox_policy.clone()), - Some(self.user_shell.clone()), + Some(self.user_shell().clone()), ))); items } async fn persist_rollout_items(&self, items: &[RolloutItem]) { let recorder = { - let guard = self.rollout.lock().await; + let guard = self.services.rollout.lock().await; guard.clone() }; if let Some(rec) = recorder @@ -738,6 +756,11 @@ impl Session { } } + pub(crate) async fn history_snapshot(&self) -> Vec { + let state = self.state.lock().await; + state.history_snapshot() + } + async fn update_token_usage_info( &self, sub_id: &str, @@ -747,12 +770,10 @@ impl Session { { let mut state = self.state.lock().await; if let Some(token_usage) = token_usage { - let info = TokenUsageInfo::new_or_append( - &state.token_info, - &Some(token_usage.clone()), + state.update_token_info_from_usage( + token_usage, turn_context.client.get_model_context_window(), ); - state.token_info = info; } } self.send_token_count_event(sub_id).await; @@ -761,7 +782,7 @@ impl Session { async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) { { let mut state = self.state.lock().await; - state.latest_rate_limits = Some(new_rate_limits); + state.set_rate_limits(new_rate_limits); } self.send_token_count_event(sub_id).await; } @@ -769,7 +790,7 @@ impl Session { async fn send_token_count_event(&self, sub_id: &str) { let (info, rate_limits) = { let state = self.state.lock().await; - (state.token_info.clone(), state.latest_rate_limits.clone()) + state.token_info_and_rate_limits() }; let event = Event { id: sub_id.to_string(), @@ -788,7 +809,7 @@ impl Session { // Derive user message events and persist only UserMessage to rollout let msgs = - map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning); + map_response_item_to_event_messages(&response_item, self.show_raw_agent_reasoning()); let user_msgs: Vec = msgs .into_iter() .filter_map(|m| match m { @@ -987,16 +1008,20 @@ impl Session { pub async fn turn_input_with_history(&self, extra: Vec) -> Vec { let history = { let state = self.state.lock().await; - state.history.contents() + state.history_snapshot() }; [history, extra].concat() } /// Returns the input if there was no task running to inject into pub async fn inject_input(&self, input: Vec) -> Result<(), Vec> { - let mut state = self.state.lock().await; + let state = self.state.lock().await; if state.current_task.is_some() { - state.pending_input.push(input.into()); + let mut active = self.active_turn.lock().await; + if let Some(at) = active.as_mut() { + let mut ts = at.turn_state.lock().await; + ts.push_pending_input(input.into()); + } Ok(()) } else { Err(input) @@ -1004,13 +1029,12 @@ impl Session { } pub async fn get_pending_input(&self) -> Vec { - let mut state = self.state.lock().await; - if state.pending_input.is_empty() { - Vec::with_capacity(0) + let mut active = self.active_turn.lock().await; + if let Some(at) = active.as_mut() { + let mut ts = at.turn_state.lock().await; + ts.take_pending_input() } else { - let mut ret = Vec::new(); - std::mem::swap(&mut ret, &mut state.pending_input); - ret + Vec::with_capacity(0) } } @@ -1020,7 +1044,8 @@ impl Session { tool: &str, arguments: Option, ) -> anyhow::Result { - self.mcp_connection_manager + self.services + .mcp_connection_manager .call_tool(server, tool, arguments) .await } @@ -1028,8 +1053,11 @@ impl Session { pub async fn interrupt_task(&self) { info!("interrupt received: abort current task, if any"); let mut state = self.state.lock().await; - state.pending_approvals.clear(); - state.pending_input.clear(); + let mut active = self.active_turn.lock().await; + if let Some(at) = active.as_mut() { + let mut ts = at.turn_state.lock().await; + ts.clear_pending(); + } if let Some(task) = state.current_task.take() { task.abort(TurnAbortReason::Interrupted); } @@ -1037,8 +1065,12 @@ impl Session { fn interrupt_task_sync(&self) { if let Ok(mut state) = self.state.try_lock() { - state.pending_approvals.clear(); - state.pending_input.clear(); + if let Ok(mut active) = self.active_turn.try_lock() + && let Some(at) = active.as_mut() + && let Ok(mut ts) = at.turn_state.try_lock() + { + ts.clear_pending(); + } if let Some(task) = state.current_task.take() { task.abort(TurnAbortReason::Interrupted); } @@ -1046,7 +1078,15 @@ impl Session { } pub(crate) fn notifier(&self) -> &UserNotifier { - &self.notifier + &self.services.notifier + } + + fn user_shell(&self) -> &shell::Shell { + &self.services.user_shell + } + + fn show_raw_agent_reasoning(&self) -> bool { + self.services.show_raw_agent_reasoning } } @@ -1152,15 +1192,19 @@ impl AgentTask { // TOCTOU? if !self.handle.is_finished() { self.handle.abort(); + let sub_id = self.sub_id.clone(); + let is_review = self.kind == AgentTaskKind::Review; + let sess = self.sess; let event = Event { - id: self.sub_id.clone(), + id: sub_id.clone(), msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }), }; - let sess = self.sess; tokio::spawn(async move { - if self.kind == AgentTaskKind::Review { - exit_review_mode(sess.clone(), self.sub_id, None).await; + if is_review { + exit_review_mode(sess.clone(), sub_id.clone(), None).await; } + // Ensure active turn state is cleared when a task is aborted. + sess.remove_task(&sub_id).await; sess.send_event(event).await; }); } @@ -1418,7 +1462,7 @@ async fn submission_loop( let sub_id = sub.id.clone(); // This is a cheap lookup from the connection manager's cache. - let tools = sess.mcp_connection_manager.list_all_tools(); + let tools = sess.services.mcp_connection_manager.list_all_tools(); let event = Event { id: sub_id, msg: EventMsg::McpListToolsResponse( @@ -1468,7 +1512,7 @@ async fn submission_loop( // Gracefully flush and shutdown rollout recorder on session end so tests // that inspect the rollout file do not race with the background writer. let recorder_opt = { - let mut guard = sess.rollout.lock().await; + let mut guard = sess.services.rollout.lock().await; guard.take() }; if let Some(rec) = recorder_opt @@ -1495,7 +1539,7 @@ async fn submission_loop( let sub_id = sub.id.clone(); // Flush rollout writes before returning the path so readers observe a consistent file. let (path, rec_opt) = { - let guard = sess.rollout.lock().await; + let guard = sess.services.rollout.lock().await; match guard.as_ref() { Some(rec) => (rec.get_rollout_path(), Some(rec.clone())), None => { @@ -1953,7 +1997,7 @@ async fn run_turn( ) -> CodexResult { let tools = get_openai_tools( &turn_context.tools_config, - Some(sess.mcp_connection_manager.list_all_tools()), + Some(sess.services.mcp_connection_manager.list_all_tools()), ); let prompt = Prompt { @@ -2203,7 +2247,7 @@ async fn try_run_turn( sess.send_event(event).await; } ResponseEvent::ReasoningContentDelta(delta) => { - if sess.show_raw_agent_reasoning { + if sess.show_raw_agent_reasoning() { let event = Event { id: sub_id.to_string(), msg: EventMsg::AgentReasoningRawContentDelta( @@ -2233,7 +2277,9 @@ async fn handle_response_item( .. } => { info!("FunctionCall: {name}({arguments})"); - if let Some((server, tool_name)) = sess.mcp_connection_manager.parse_tool_name(&name) { + if let Some((server, tool_name)) = + sess.services.mcp_connection_manager.parse_tool_name(&name) + { let resp = handle_mcp_tool_call( sess, sub_id, @@ -2369,7 +2415,7 @@ async fn handle_response_item( trace!("suppressing assistant Message in review mode"); Vec::new() } - _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning), + _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()), }; for msg in msgs { let event = Event { @@ -2411,6 +2457,7 @@ async fn handle_unified_exec_tool_call( }; let value = sess + .services .unified_exec_manager .handle_request(request) .await @@ -2529,6 +2576,7 @@ async fn handle_function_call( )) })?; let result = sess + .services .session_manager .handle_exec_command_request(exec_params) .await; @@ -2546,6 +2594,7 @@ async fn handle_function_call( })?; let result = sess + .services .session_manager .handle_write_stdin_request(write_stdin_params) .await @@ -2636,12 +2685,12 @@ fn maybe_translate_shell_command( sess: &Session, turn_context: &TurnContext, ) -> ExecParams { - let should_translate = matches!(sess.user_shell, crate::shell::Shell::PowerShell(_)) + let should_translate = matches!(sess.user_shell(), crate::shell::Shell::PowerShell(_)) || turn_context.shell_environment_policy.use_profile; if should_translate && let Some(command) = sess - .user_shell + .user_shell() .format_default_shell_invocation(params.command.clone()) { return ExecParams { command, ..params }; @@ -2741,7 +2790,7 @@ async fn handle_container_exec_with_params( ¶ms.command, turn_context.approval_policy, &turn_context.sandbox_policy, - &state.approved_commands, + state.approved_commands_ref(), params.with_escalated_permissions.unwrap_or(false), ) }; @@ -2812,7 +2861,7 @@ async fn handle_container_exec_with_params( sandbox_type, sandbox_policy: &turn_context.sandbox_policy, sandbox_cwd: &turn_context.cwd, - codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe, + codex_linux_sandbox_exe: &sess.services.codex_linux_sandbox_exe, stdout_stream: if exec_command_context.apply_patch.is_some() { None } else { @@ -2927,7 +2976,7 @@ async fn handle_sandbox_error( sandbox_type: SandboxType::None, sandbox_policy: &turn_context.sandbox_policy, sandbox_cwd: &turn_context.cwd, - codex_linux_sandbox_exe: &sess.codex_linux_sandbox_exe, + codex_linux_sandbox_exe: &sess.services.codex_linux_sandbox_exe, stdout_stream: if exec_command_context.apply_patch.is_some() { None } else { @@ -3268,7 +3317,7 @@ mod tests { }), )); - let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); + let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() }); assert_eq!(expected, actual); } @@ -3281,7 +3330,7 @@ mod tests { session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)), ); - let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); + let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() }); assert_eq!(expected, actual); } @@ -3507,21 +3556,22 @@ mod tests { is_review_mode: false, final_output_json_schema: None, }; - let session = Session { - conversation_id, - tx_event, + let services = SessionServices { mcp_connection_manager: McpConnectionManager::default(), session_manager: ExecSessionManager::default(), unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::default(), rollout: Mutex::new(None), - state: Mutex::new(State { - history: ConversationHistory::new(), - ..Default::default() - }), codex_linux_sandbox_exe: None, user_shell: shell::Shell::Unknown, show_raw_agent_reasoning: config.show_raw_agent_reasoning, + }; + let session = Session { + conversation_id, + tx_event, + state: Mutex::new(SessionState::new()), + active_turn: Mutex::new(None), + services, next_internal_sub_id: AtomicU64::new(0), }; (session, turn_context) diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index 8f213d4e..4facd45d 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -151,18 +151,12 @@ async fn run_compact_task_inner( if remove_task_on_completion { sess.remove_task(&sub_id).await; } - let history_snapshot = { - let state = sess.state.lock().await; - state.history.contents() - }; + let history_snapshot = sess.history_snapshot().await; let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default(); let user_messages = collect_user_messages(&history_snapshot); let initial_context = sess.build_initial_context(turn_context.as_ref()); let new_history = build_compacted_history(initial_context, &user_messages, &summary_text); - { - let mut state = sess.state.lock().await; - state.history.replace(new_history); - } + sess.replace_history(new_history).await; let rollout_item = RolloutItem::Compacted(CompactedItem { message: summary_text.clone(), @@ -270,8 +264,7 @@ async fn drain_to_completed( }; match event { Ok(ResponseEvent::OutputItemDone(item)) => { - let mut state = sess.state.lock().await; - state.history.record_items(std::slice::from_ref(&item)); + sess.record_into_history(std::slice::from_ref(&item)).await; } Ok(ResponseEvent::Completed { .. }) => { return Ok(()); diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 2db1e6e7..36287c1a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -76,6 +76,7 @@ pub use rollout::list::ConversationItem; pub use rollout::list::ConversationsPage; pub use rollout::list::Cursor; mod function_tool; +mod state; mod user_notification; pub mod util; diff --git a/codex-rs/core/src/state/mod.rs b/codex-rs/core/src/state/mod.rs new file mode 100644 index 00000000..927f5981 --- /dev/null +++ b/codex-rs/core/src/state/mod.rs @@ -0,0 +1,8 @@ +mod service; +mod session; +mod turn; + +pub(crate) use service::SessionServices; +pub(crate) use session::SessionState; +pub(crate) use turn::ActiveTurn; +pub(crate) use turn::TurnState; diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs new file mode 100644 index 00000000..a67b9dda --- /dev/null +++ b/codex-rs/core/src/state/service.rs @@ -0,0 +1,18 @@ +use crate::RolloutRecorder; +use crate::exec_command::ExecSessionManager; +use crate::mcp_connection_manager::McpConnectionManager; +use crate::unified_exec::UnifiedExecSessionManager; +use crate::user_notification::UserNotifier; +use std::path::PathBuf; +use tokio::sync::Mutex; + +pub(crate) struct SessionServices { + pub(crate) mcp_connection_manager: McpConnectionManager, + pub(crate) session_manager: ExecSessionManager, + pub(crate) unified_exec_manager: UnifiedExecSessionManager, + pub(crate) notifier: UserNotifier, + pub(crate) rollout: Mutex>, + pub(crate) codex_linux_sandbox_exe: Option, + pub(crate) user_shell: crate::shell::Shell, + pub(crate) show_raw_agent_reasoning: bool, +} diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs new file mode 100644 index 00000000..f8afbdec --- /dev/null +++ b/codex-rs/core/src/state/session.rs @@ -0,0 +1,82 @@ +//! Session-wide mutable state. + +use std::collections::HashSet; + +use codex_protocol::models::ResponseItem; + +use crate::codex::AgentTask; +use crate::conversation_history::ConversationHistory; +use crate::protocol::RateLimitSnapshot; +use crate::protocol::TokenUsage; +use crate::protocol::TokenUsageInfo; + +/// Persistent, session-scoped state previously stored directly on `Session`. +#[derive(Default)] +pub(crate) struct SessionState { + pub(crate) approved_commands: HashSet>, + pub(crate) current_task: Option, + pub(crate) history: ConversationHistory, + pub(crate) token_info: Option, + pub(crate) latest_rate_limits: Option, +} + +impl SessionState { + /// Create a new session state mirroring previous `State::default()` semantics. + pub(crate) fn new() -> Self { + Self { + history: ConversationHistory::new(), + ..Default::default() + } + } + + // History helpers + pub(crate) fn record_items(&mut self, items: I) + where + I: IntoIterator, + I::Item: std::ops::Deref, + { + self.history.record_items(items) + } + + pub(crate) fn history_snapshot(&self) -> Vec { + self.history.contents() + } + + pub(crate) fn replace_history(&mut self, items: Vec) { + self.history.replace(items); + } + + // Approved command helpers + pub(crate) fn add_approved_command(&mut self, cmd: Vec) { + self.approved_commands.insert(cmd); + } + + pub(crate) fn approved_commands_ref(&self) -> &HashSet> { + &self.approved_commands + } + + // Token/rate limit helpers + pub(crate) fn update_token_info_from_usage( + &mut self, + usage: &TokenUsage, + model_context_window: Option, + ) { + self.token_info = TokenUsageInfo::new_or_append( + &self.token_info, + &Some(usage.clone()), + model_context_window, + ); + } + + pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) { + self.latest_rate_limits = Some(snapshot); + } + + pub(crate) fn token_info_and_rate_limits( + &self, + ) -> (Option, Option) { + (self.token_info.clone(), self.latest_rate_limits.clone()) + } + + // Pending input/approval moved to TurnState. +} diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs new file mode 100644 index 00000000..b49c86b5 --- /dev/null +++ b/codex-rs/core/src/state/turn.rs @@ -0,0 +1,60 @@ +//! Turn-scoped state and active turn metadata scaffolding. + +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +use codex_protocol::models::ResponseInputItem; +use tokio::sync::oneshot; + +use crate::protocol::ReviewDecision; + +/// Metadata about the currently running turn. +#[derive(Default)] +pub(crate) struct ActiveTurn { + pub(crate) sub_id: String, + pub(crate) turn_state: Arc>, +} + +/// Mutable state for a single turn. +#[derive(Default)] +pub(crate) struct TurnState { + pending_approvals: HashMap>, + pending_input: Vec, +} + +impl TurnState { + pub(crate) fn insert_pending_approval( + &mut self, + key: String, + tx: oneshot::Sender, + ) -> Option> { + self.pending_approvals.insert(key, tx) + } + + pub(crate) fn remove_pending_approval( + &mut self, + key: &str, + ) -> Option> { + self.pending_approvals.remove(key) + } + + pub(crate) fn clear_pending(&mut self) { + self.pending_approvals.clear(); + self.pending_input.clear(); + } + + pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) { + self.pending_input.push(input); + } + + pub(crate) fn take_pending_input(&mut self) -> Vec { + if self.pending_input.is_empty() { + Vec::with_capacity(0) + } else { + let mut ret = Vec::new(); + std::mem::swap(&mut ret, &mut self.pending_input); + ret + } + } +}