fix: introduce MutexExt::lock_unchecked() so we stop ignoring unwrap() throughout codex.rs (#2340)
This way we are sure a dangerous `unwrap()` does not sneak in! --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2340). * #2345 * #2329 * #2343 * __->__ #2340 * #2338
This commit is contained in:
@@ -1,6 +1,3 @@
|
|||||||
// Poisoned mutex should fail the program
|
|
||||||
#![expect(clippy::unwrap_used)]
|
|
||||||
|
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
@@ -8,6 +5,7 @@ use std::path::Path;
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
use std::sync::MutexGuard;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::AtomicU64;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
@@ -108,6 +106,21 @@ use crate::turn_diff_tracker::TurnDiffTracker;
|
|||||||
use crate::user_notification::UserNotification;
|
use crate::user_notification::UserNotification;
|
||||||
use crate::util::backoff;
|
use crate::util::backoff;
|
||||||
|
|
||||||
|
// A convenience extension trait for acquiring mutex locks where poisoning is
|
||||||
|
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
|
||||||
|
// calls on `lock()` while still surfacing a clear panic message when a lock is
|
||||||
|
// poisoned.
|
||||||
|
trait MutexExt<T> {
|
||||||
|
fn lock_unchecked(&self) -> MutexGuard<'_, T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> MutexExt<T> for Mutex<T> {
|
||||||
|
fn lock_unchecked(&self) -> MutexGuard<'_, T> {
|
||||||
|
#[expect(clippy::expect_used)]
|
||||||
|
self.lock().expect("poisoned lock")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The high-level interface to the Codex system.
|
/// The high-level interface to the Codex system.
|
||||||
/// It operates as a queue pair where you send submissions and receive events.
|
/// It operates as a queue pair where you send submissions and receive events.
|
||||||
pub struct Codex {
|
pub struct Codex {
|
||||||
@@ -523,7 +536,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_task(&self, task: AgentTask) {
|
pub fn set_task(&self, task: AgentTask) {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
if let Some(current_task) = state.current_task.take() {
|
if let Some(current_task) = state.current_task.take() {
|
||||||
current_task.abort();
|
current_task.abort();
|
||||||
}
|
}
|
||||||
@@ -531,7 +544,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_task(&self, sub_id: &str) {
|
pub fn remove_task(&self, sub_id: &str) {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
if let Some(task) = &state.current_task {
|
if let Some(task) = &state.current_task {
|
||||||
if task.sub_id == sub_id {
|
if task.sub_id == sub_id {
|
||||||
state.current_task.take();
|
state.current_task.take();
|
||||||
@@ -567,7 +580,7 @@ impl Session {
|
|||||||
};
|
};
|
||||||
let _ = self.tx_event.send(event).await;
|
let _ = self.tx_event.send(event).await;
|
||||||
{
|
{
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
state.pending_approvals.insert(sub_id, tx_approve);
|
state.pending_approvals.insert(sub_id, tx_approve);
|
||||||
}
|
}
|
||||||
rx_approve
|
rx_approve
|
||||||
@@ -593,21 +606,21 @@ impl Session {
|
|||||||
};
|
};
|
||||||
let _ = self.tx_event.send(event).await;
|
let _ = self.tx_event.send(event).await;
|
||||||
{
|
{
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
state.pending_approvals.insert(sub_id, tx_approve);
|
state.pending_approvals.insert(sub_id, tx_approve);
|
||||||
}
|
}
|
||||||
rx_approve
|
rx_approve
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) {
|
pub fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
if let Some(tx_approve) = state.pending_approvals.remove(sub_id) {
|
if let Some(tx_approve) = state.pending_approvals.remove(sub_id) {
|
||||||
tx_approve.send(decision).ok();
|
tx_approve.send(decision).ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_approved_command(&self, cmd: Vec<String>) {
|
pub fn add_approved_command(&self, cmd: Vec<String>) {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
state.approved_commands.insert(cmd);
|
state.approved_commands.insert(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -617,14 +630,14 @@ impl Session {
|
|||||||
debug!("Recording items for conversation: {items:?}");
|
debug!("Recording items for conversation: {items:?}");
|
||||||
self.record_state_snapshot(items).await;
|
self.record_state_snapshot(items).await;
|
||||||
|
|
||||||
self.state.lock().unwrap().history.record_items(items);
|
self.state.lock_unchecked().history.record_items(items);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||||
let snapshot = { crate::rollout::SessionStateSnapshot {} };
|
let snapshot = { crate::rollout::SessionStateSnapshot {} };
|
||||||
|
|
||||||
let recorder = {
|
let recorder = {
|
||||||
let guard = self.rollout.lock().unwrap();
|
let guard = self.rollout.lock_unchecked();
|
||||||
guard.as_ref().cloned()
|
guard.as_ref().cloned()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -802,12 +815,12 @@ impl Session {
|
|||||||
/// Build the full turn input by concatenating the current conversation
|
/// Build the full turn input by concatenating the current conversation
|
||||||
/// history with additional items for this turn.
|
/// history with additional items for this turn.
|
||||||
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
pub fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||||
[self.state.lock().unwrap().history.contents(), extra].concat()
|
[self.state.lock_unchecked().history.contents(), extra].concat()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the input if there was no task running to inject into
|
/// Returns the input if there was no task running to inject into
|
||||||
pub fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
pub fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
if state.current_task.is_some() {
|
if state.current_task.is_some() {
|
||||||
state.pending_input.push(input.into());
|
state.pending_input.push(input.into());
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -817,7 +830,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_pending_input(&self) -> Vec<ResponseInputItem> {
|
pub fn get_pending_input(&self) -> Vec<ResponseInputItem> {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
if state.pending_input.is_empty() {
|
if state.pending_input.is_empty() {
|
||||||
Vec::with_capacity(0)
|
Vec::with_capacity(0)
|
||||||
} else {
|
} else {
|
||||||
@@ -841,7 +854,7 @@ impl Session {
|
|||||||
|
|
||||||
fn abort(&self) {
|
fn abort(&self) {
|
||||||
info!("Aborting existing session");
|
info!("Aborting existing session");
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock_unchecked();
|
||||||
state.pending_approvals.clear();
|
state.pending_approvals.clear();
|
||||||
state.pending_input.clear();
|
state.pending_input.clear();
|
||||||
if let Some(task) = state.current_task.take() {
|
if let Some(task) = state.current_task.take() {
|
||||||
@@ -1045,7 +1058,7 @@ async fn submission_loop(sess: Arc<Session>, config: Arc<Config>, rx_sub: Receiv
|
|||||||
|
|
||||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||||
// that inspect the rollout file do not race with the background writer.
|
// that inspect the rollout file do not race with the background writer.
|
||||||
let recorder_opt = sess.rollout.lock().unwrap().take();
|
let recorder_opt = sess.rollout.lock_unchecked().take();
|
||||||
if let Some(rec) = recorder_opt {
|
if let Some(rec) = recorder_opt {
|
||||||
if let Err(e) = rec.shutdown().await {
|
if let Err(e) = rec.shutdown().await {
|
||||||
warn!("failed to shutdown rollout recorder: {e}");
|
warn!("failed to shutdown rollout recorder: {e}");
|
||||||
@@ -1461,7 +1474,7 @@ async fn try_run_turn(
|
|||||||
}
|
}
|
||||||
ResponseEvent::OutputTextDelta(delta) => {
|
ResponseEvent::OutputTextDelta(delta) => {
|
||||||
{
|
{
|
||||||
let mut st = sess.state.lock().unwrap();
|
let mut st = sess.state.lock_unchecked();
|
||||||
st.history.append_assistant_text(&delta);
|
st.history.append_assistant_text(&delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1577,7 +1590,7 @@ async fn run_compact_task(
|
|||||||
};
|
};
|
||||||
sess.send_event(event).await;
|
sess.send_event(event).await;
|
||||||
|
|
||||||
let mut state = sess.state.lock().unwrap();
|
let mut state = sess.state.lock_unchecked();
|
||||||
state.history.keep_last_messages(1);
|
state.history.keep_last_messages(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1617,8 +1630,9 @@ async fn handle_response_item(
|
|||||||
};
|
};
|
||||||
sess.tx_event.send(event).await.ok();
|
sess.tx_event.send(event).await.ok();
|
||||||
}
|
}
|
||||||
if sess.show_raw_agent_reasoning && content.is_some() {
|
if sess.show_raw_agent_reasoning
|
||||||
let content = content.unwrap();
|
&& let Some(content) = content
|
||||||
|
{
|
||||||
for item in content {
|
for item in content {
|
||||||
let text = match item {
|
let text = match item {
|
||||||
ReasoningItemContent::ReasoningText { text } => text,
|
ReasoningItemContent::ReasoningText { text } => text,
|
||||||
@@ -1912,7 +1926,7 @@ async fn handle_container_exec_with_params(
|
|||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let safety = {
|
let safety = {
|
||||||
let state = sess.state.lock().unwrap();
|
let state = sess.state.lock_unchecked();
|
||||||
assess_command_safety(
|
assess_command_safety(
|
||||||
¶ms.command,
|
¶ms.command,
|
||||||
sess.approval_policy,
|
sess.approval_policy,
|
||||||
@@ -2252,7 +2266,7 @@ async fn drain_to_completed(sess: &Session, sub_id: &str, prompt: &Prompt) -> Co
|
|||||||
match event {
|
match event {
|
||||||
Ok(ResponseEvent::OutputItemDone(item)) => {
|
Ok(ResponseEvent::OutputItemDone(item)) => {
|
||||||
// Record only to in-memory conversation history; avoid state snapshot.
|
// Record only to in-memory conversation history; avoid state snapshot.
|
||||||
let mut state = sess.state.lock().unwrap();
|
let mut state = sess.state.lock_unchecked();
|
||||||
state.history.record_items(std::slice::from_ref(&item));
|
state.history.record_items(std::slice::from_ref(&item));
|
||||||
}
|
}
|
||||||
Ok(ResponseEvent::Completed {
|
Ok(ResponseEvent::Completed {
|
||||||
|
|||||||
Reference in New Issue
Block a user