fix: race condition unified exec (#3644)
Fix race condition without storing an rx in the session
This commit is contained in:
@@ -11,9 +11,6 @@ pub(crate) struct ExecCommandSession {
|
|||||||
/// Broadcast stream of output chunks read from the PTY. New subscribers
|
/// Broadcast stream of output chunks read from the PTY. New subscribers
|
||||||
/// receive only chunks emitted after they subscribe.
|
/// receive only chunks emitted after they subscribe.
|
||||||
output_tx: broadcast::Sender<Vec<u8>>,
|
output_tx: broadcast::Sender<Vec<u8>>,
|
||||||
/// Receiver subscribed before the child process starts emitting output so
|
|
||||||
/// the first caller can consume any early data without races.
|
|
||||||
initial_output_rx: StdMutex<Option<broadcast::Receiver<Vec<u8>>>>,
|
|
||||||
|
|
||||||
/// Child killer handle for termination on drop (can signal independently
|
/// Child killer handle for termination on drop (can signal independently
|
||||||
/// of a thread blocked in `.wait()`).
|
/// of a thread blocked in `.wait()`).
|
||||||
@@ -41,25 +38,20 @@ impl ExecCommandSession {
|
|||||||
writer_handle: JoinHandle<()>,
|
writer_handle: JoinHandle<()>,
|
||||||
wait_handle: JoinHandle<()>,
|
wait_handle: JoinHandle<()>,
|
||||||
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||||
) -> Self {
|
) -> (Self, broadcast::Receiver<Vec<u8>>) {
|
||||||
Self {
|
let initial_output_rx = output_tx.subscribe();
|
||||||
writer_tx,
|
(
|
||||||
output_tx,
|
Self {
|
||||||
initial_output_rx: StdMutex::new(None),
|
writer_tx,
|
||||||
killer: StdMutex::new(Some(killer)),
|
output_tx,
|
||||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
killer: StdMutex::new(Some(killer)),
|
||||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||||
exit_status,
|
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||||
}
|
exit_status,
|
||||||
}
|
},
|
||||||
|
initial_output_rx,
|
||||||
pub(crate) fn set_initial_output_receiver(&self, receiver: broadcast::Receiver<Vec<u8>>) {
|
)
|
||||||
if let Ok(mut guard) = self.initial_output_rx.lock()
|
|
||||||
&& guard.is_none()
|
|
||||||
{
|
|
||||||
*guard = Some(receiver);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||||
@@ -67,13 +59,7 @@ impl ExecCommandSession {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||||
if let Ok(mut guard) = self.initial_output_rx.lock()
|
self.output_tx.subscribe()
|
||||||
&& let Some(receiver) = guard.take()
|
|
||||||
{
|
|
||||||
receiver
|
|
||||||
} else {
|
|
||||||
self.output_tx.subscribe()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn has_exited(&self) -> bool {
|
pub(crate) fn has_exited(&self) -> bool {
|
||||||
|
|||||||
@@ -93,18 +93,16 @@ impl SessionManager {
|
|||||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
||||||
);
|
);
|
||||||
|
|
||||||
let (session, mut exit_rx) =
|
let (session, mut output_rx, mut exit_rx) = create_exec_command_session(params.clone())
|
||||||
create_exec_command_session(params.clone())
|
.await
|
||||||
.await
|
.map_err(|err| {
|
||||||
.map_err(|err| {
|
format!(
|
||||||
format!(
|
"failed to create exec command session for session id {}: {err}",
|
||||||
"failed to create exec command session for session id {}: {err}",
|
session_id.0
|
||||||
session_id.0
|
)
|
||||||
)
|
})?;
|
||||||
})?;
|
|
||||||
|
|
||||||
// Insert into session map.
|
// Insert into session map.
|
||||||
let mut output_rx = session.output_receiver();
|
|
||||||
self.sessions.lock().await.insert(session_id, session);
|
self.sessions.lock().await.insert(session_id, session);
|
||||||
|
|
||||||
// Collect output until either timeout expires or process exits.
|
// Collect output until either timeout expires or process exits.
|
||||||
@@ -245,7 +243,11 @@ impl SessionManager {
|
|||||||
/// Spawn PTY and child process per spawn_exec_command_session logic.
|
/// Spawn PTY and child process per spawn_exec_command_session logic.
|
||||||
async fn create_exec_command_session(
|
async fn create_exec_command_session(
|
||||||
params: ExecCommandParams,
|
params: ExecCommandParams,
|
||||||
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
|
) -> anyhow::Result<(
|
||||||
|
ExecCommandSession,
|
||||||
|
tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||||
|
oneshot::Receiver<i32>,
|
||||||
|
)> {
|
||||||
let ExecCommandParams {
|
let ExecCommandParams {
|
||||||
cmd,
|
cmd,
|
||||||
yield_time_ms: _,
|
yield_time_ms: _,
|
||||||
@@ -279,8 +281,6 @@ async fn create_exec_command_session(
|
|||||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||||
// Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
|
// Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
|
||||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||||
let initial_output_rx = output_tx.subscribe();
|
|
||||||
|
|
||||||
// Reader task: drain PTY and forward chunks to output channel.
|
// Reader task: drain PTY and forward chunks to output channel.
|
||||||
let mut reader = pair.master.try_clone_reader()?;
|
let mut reader = pair.master.try_clone_reader()?;
|
||||||
let output_tx_clone = output_tx.clone();
|
let output_tx_clone = output_tx.clone();
|
||||||
@@ -342,7 +342,7 @@ async fn create_exec_command_session(
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Create and store the session with channels.
|
// Create and store the session with channels.
|
||||||
let session = ExecCommandSession::new(
|
let (session, initial_output_rx) = ExecCommandSession::new(
|
||||||
writer_tx,
|
writer_tx,
|
||||||
output_tx,
|
output_tx,
|
||||||
killer,
|
killer,
|
||||||
@@ -351,8 +351,7 @@ async fn create_exec_command_session(
|
|||||||
wait_handle,
|
wait_handle,
|
||||||
exit_status,
|
exit_status,
|
||||||
);
|
);
|
||||||
session.set_initial_output_receiver(initial_output_rx);
|
Ok((session, initial_output_rx, exit_rx))
|
||||||
Ok((session, exit_rx))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -100,10 +100,13 @@ type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
|||||||
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||||
|
|
||||||
impl ManagedUnifiedExecSession {
|
impl ManagedUnifiedExecSession {
|
||||||
fn new(session: ExecCommandSession) -> Self {
|
fn new(
|
||||||
|
session: ExecCommandSession,
|
||||||
|
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||||
|
) -> Self {
|
||||||
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||||
let output_notify = Arc::new(Notify::new());
|
let output_notify = Arc::new(Notify::new());
|
||||||
let mut receiver = session.output_receiver();
|
let mut receiver = initial_output_rx;
|
||||||
let buffer_clone = Arc::clone(&output_buffer);
|
let buffer_clone = Arc::clone(&output_buffer);
|
||||||
let notify_clone = Arc::clone(&output_notify);
|
let notify_clone = Arc::clone(&output_notify);
|
||||||
let output_task = tokio::spawn(async move {
|
let output_task = tokio::spawn(async move {
|
||||||
@@ -193,8 +196,8 @@ impl UnifiedExecSessionManager {
|
|||||||
} else {
|
} else {
|
||||||
let command = request.input_chunks.to_vec();
|
let command = request.input_chunks.to_vec();
|
||||||
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
||||||
let session = create_unified_exec_session(&command).await?;
|
let (session, initial_output_rx) = create_unified_exec_session(&command).await?;
|
||||||
let managed_session = ManagedUnifiedExecSession::new(session);
|
let managed_session = ManagedUnifiedExecSession::new(session, initial_output_rx);
|
||||||
let (buffer, notify) = managed_session.output_handles();
|
let (buffer, notify) = managed_session.output_handles();
|
||||||
writer_tx = managed_session.writer_sender();
|
writer_tx = managed_session.writer_sender();
|
||||||
output_buffer = buffer;
|
output_buffer = buffer;
|
||||||
@@ -297,7 +300,13 @@ impl UnifiedExecSessionManager {
|
|||||||
|
|
||||||
async fn create_unified_exec_session(
|
async fn create_unified_exec_session(
|
||||||
command: &[String],
|
command: &[String],
|
||||||
) -> Result<ExecCommandSession, UnifiedExecError> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
ExecCommandSession,
|
||||||
|
tokio::sync::broadcast::Receiver<Vec<u8>>,
|
||||||
|
),
|
||||||
|
UnifiedExecError,
|
||||||
|
> {
|
||||||
if command.is_empty() {
|
if command.is_empty() {
|
||||||
return Err(UnifiedExecError::MissingCommandLine);
|
return Err(UnifiedExecError::MissingCommandLine);
|
||||||
}
|
}
|
||||||
@@ -327,7 +336,6 @@ async fn create_unified_exec_session(
|
|||||||
|
|
||||||
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||||
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||||
let initial_output_rx = output_tx.subscribe();
|
|
||||||
|
|
||||||
let mut reader = pair
|
let mut reader = pair
|
||||||
.master
|
.master
|
||||||
@@ -381,7 +389,7 @@ async fn create_unified_exec_session(
|
|||||||
wait_exit_status.store(true, Ordering::SeqCst);
|
wait_exit_status.store(true, Ordering::SeqCst);
|
||||||
});
|
});
|
||||||
|
|
||||||
let session = ExecCommandSession::new(
|
let (session, initial_output_rx) = ExecCommandSession::new(
|
||||||
writer_tx,
|
writer_tx,
|
||||||
output_tx,
|
output_tx,
|
||||||
killer,
|
killer,
|
||||||
@@ -390,9 +398,7 @@ async fn create_unified_exec_session(
|
|||||||
wait_handle,
|
wait_handle,
|
||||||
exit_status,
|
exit_status,
|
||||||
);
|
);
|
||||||
session.set_initial_output_receiver(initial_output_rx);
|
Ok((session, initial_output_rx))
|
||||||
|
|
||||||
Ok(session)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
Reference in New Issue
Block a user