fix: race condition unified exec (#3644)

Fix race condition without storing an rx in the session
This commit is contained in:
jimmyfraiture2
2025-09-15 14:52:39 +01:00
committed by GitHub
parent 9baa5c33da
commit d555b68469
3 changed files with 46 additions and 55 deletions

View File

@@ -11,9 +11,6 @@ pub(crate) struct ExecCommandSession {
/// Broadcast stream of output chunks read from the PTY. New subscribers /// Broadcast stream of output chunks read from the PTY. New subscribers
/// receive only chunks emitted after they subscribe. /// receive only chunks emitted after they subscribe.
output_tx: broadcast::Sender<Vec<u8>>, output_tx: broadcast::Sender<Vec<u8>>,
/// Receiver subscribed before the child process starts emitting output so
/// the first caller can consume any early data without races.
initial_output_rx: StdMutex<Option<broadcast::Receiver<Vec<u8>>>>,
/// Child killer handle for termination on drop (can signal independently /// Child killer handle for termination on drop (can signal independently
/// of a thread blocked in `.wait()`). /// of a thread blocked in `.wait()`).
@@ -41,25 +38,20 @@ impl ExecCommandSession {
writer_handle: JoinHandle<()>, writer_handle: JoinHandle<()>,
wait_handle: JoinHandle<()>, wait_handle: JoinHandle<()>,
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>, exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
) -> Self { ) -> (Self, broadcast::Receiver<Vec<u8>>) {
Self { let initial_output_rx = output_tx.subscribe();
writer_tx, (
output_tx, Self {
initial_output_rx: StdMutex::new(None), writer_tx,
killer: StdMutex::new(Some(killer)), output_tx,
reader_handle: StdMutex::new(Some(reader_handle)), killer: StdMutex::new(Some(killer)),
writer_handle: StdMutex::new(Some(writer_handle)), reader_handle: StdMutex::new(Some(reader_handle)),
wait_handle: StdMutex::new(Some(wait_handle)), writer_handle: StdMutex::new(Some(writer_handle)),
exit_status, wait_handle: StdMutex::new(Some(wait_handle)),
} exit_status,
} },
initial_output_rx,
pub(crate) fn set_initial_output_receiver(&self, receiver: broadcast::Receiver<Vec<u8>>) { )
if let Ok(mut guard) = self.initial_output_rx.lock()
&& guard.is_none()
{
*guard = Some(receiver);
}
} }
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> { pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
@@ -67,13 +59,7 @@ impl ExecCommandSession {
} }
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> { pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
if let Ok(mut guard) = self.initial_output_rx.lock() self.output_tx.subscribe()
&& let Some(receiver) = guard.take()
{
receiver
} else {
self.output_tx.subscribe()
}
} }
pub(crate) fn has_exited(&self) -> bool { pub(crate) fn has_exited(&self) -> bool {

View File

@@ -93,18 +93,16 @@ impl SessionManager {
.fetch_add(1, std::sync::atomic::Ordering::SeqCst), .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
); );
let (session, mut exit_rx) = let (session, mut output_rx, mut exit_rx) = create_exec_command_session(params.clone())
create_exec_command_session(params.clone()) .await
.await .map_err(|err| {
.map_err(|err| { format!(
format!( "failed to create exec command session for session id {}: {err}",
"failed to create exec command session for session id {}: {err}", session_id.0
session_id.0 )
) })?;
})?;
// Insert into session map. // Insert into session map.
let mut output_rx = session.output_receiver();
self.sessions.lock().await.insert(session_id, session); self.sessions.lock().await.insert(session_id, session);
// Collect output until either timeout expires or process exits. // Collect output until either timeout expires or process exits.
@@ -245,7 +243,11 @@ impl SessionManager {
/// Spawn PTY and child process per spawn_exec_command_session logic. /// Spawn PTY and child process per spawn_exec_command_session logic.
async fn create_exec_command_session( async fn create_exec_command_session(
params: ExecCommandParams, params: ExecCommandParams,
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> { ) -> anyhow::Result<(
ExecCommandSession,
tokio::sync::broadcast::Receiver<Vec<u8>>,
oneshot::Receiver<i32>,
)> {
let ExecCommandParams { let ExecCommandParams {
cmd, cmd,
yield_time_ms: _, yield_time_ms: _,
@@ -279,8 +281,6 @@ async fn create_exec_command_session(
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128); let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
// Broadcast for streaming PTY output to readers: subscribers receive from subscription time. // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256); let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
let initial_output_rx = output_tx.subscribe();
// Reader task: drain PTY and forward chunks to output channel. // Reader task: drain PTY and forward chunks to output channel.
let mut reader = pair.master.try_clone_reader()?; let mut reader = pair.master.try_clone_reader()?;
let output_tx_clone = output_tx.clone(); let output_tx_clone = output_tx.clone();
@@ -342,7 +342,7 @@ async fn create_exec_command_session(
}); });
// Create and store the session with channels. // Create and store the session with channels.
let session = ExecCommandSession::new( let (session, initial_output_rx) = ExecCommandSession::new(
writer_tx, writer_tx,
output_tx, output_tx,
killer, killer,
@@ -351,8 +351,7 @@ async fn create_exec_command_session(
wait_handle, wait_handle,
exit_status, exit_status,
); );
session.set_initial_output_receiver(initial_output_rx); Ok((session, initial_output_rx, exit_rx))
Ok((session, exit_rx))
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -100,10 +100,13 @@ type OutputBuffer = Arc<Mutex<OutputBufferState>>;
type OutputHandles = (OutputBuffer, Arc<Notify>); type OutputHandles = (OutputBuffer, Arc<Notify>);
impl ManagedUnifiedExecSession { impl ManagedUnifiedExecSession {
fn new(session: ExecCommandSession) -> Self { fn new(
session: ExecCommandSession,
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
) -> Self {
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default())); let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
let output_notify = Arc::new(Notify::new()); let output_notify = Arc::new(Notify::new());
let mut receiver = session.output_receiver(); let mut receiver = initial_output_rx;
let buffer_clone = Arc::clone(&output_buffer); let buffer_clone = Arc::clone(&output_buffer);
let notify_clone = Arc::clone(&output_notify); let notify_clone = Arc::clone(&output_notify);
let output_task = tokio::spawn(async move { let output_task = tokio::spawn(async move {
@@ -193,8 +196,8 @@ impl UnifiedExecSessionManager {
} else { } else {
let command = request.input_chunks.to_vec(); let command = request.input_chunks.to_vec();
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst); let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
let session = create_unified_exec_session(&command).await?; let (session, initial_output_rx) = create_unified_exec_session(&command).await?;
let managed_session = ManagedUnifiedExecSession::new(session); let managed_session = ManagedUnifiedExecSession::new(session, initial_output_rx);
let (buffer, notify) = managed_session.output_handles(); let (buffer, notify) = managed_session.output_handles();
writer_tx = managed_session.writer_sender(); writer_tx = managed_session.writer_sender();
output_buffer = buffer; output_buffer = buffer;
@@ -297,7 +300,13 @@ impl UnifiedExecSessionManager {
async fn create_unified_exec_session( async fn create_unified_exec_session(
command: &[String], command: &[String],
) -> Result<ExecCommandSession, UnifiedExecError> { ) -> Result<
(
ExecCommandSession,
tokio::sync::broadcast::Receiver<Vec<u8>>,
),
UnifiedExecError,
> {
if command.is_empty() { if command.is_empty() {
return Err(UnifiedExecError::MissingCommandLine); return Err(UnifiedExecError::MissingCommandLine);
} }
@@ -327,7 +336,6 @@ async fn create_unified_exec_session(
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128); let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256); let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
let initial_output_rx = output_tx.subscribe();
let mut reader = pair let mut reader = pair
.master .master
@@ -381,7 +389,7 @@ async fn create_unified_exec_session(
wait_exit_status.store(true, Ordering::SeqCst); wait_exit_status.store(true, Ordering::SeqCst);
}); });
let session = ExecCommandSession::new( let (session, initial_output_rx) = ExecCommandSession::new(
writer_tx, writer_tx,
output_tx, output_tx,
killer, killer,
@@ -390,9 +398,7 @@ async fn create_unified_exec_session(
wait_handle, wait_handle,
exit_status, exit_status,
); );
session.set_initial_output_receiver(initial_output_rx); Ok((session, initial_output_rx))
Ok(session)
} }
#[cfg(test)] #[cfg(test)]