use async_trait::async_trait; use std::future::Future; use tokio_util::sync::CancellationToken; #[derive(Debug, PartialEq, Eq)] pub enum CancelErr { Cancelled, } #[async_trait] pub trait OrCancelExt: Sized { type Output; async fn or_cancel(self, token: &CancellationToken) -> Result; } #[async_trait] impl OrCancelExt for F where F: Future + Send, F::Output: Send, { type Output = F::Output; async fn or_cancel(self, token: &CancellationToken) -> Result { tokio::select! { _ = token.cancelled() => Err(CancelErr::Cancelled), res = self => Ok(res), } } } #[cfg(test)] mod tests { use super::*; use pretty_assertions::assert_eq; use std::time::Duration; use tokio::task; use tokio::time::sleep; #[tokio::test] async fn returns_ok_when_future_completes_first() { let token = CancellationToken::new(); let value = async { 42 }; let result = value.or_cancel(&token).await; assert_eq!(Ok(42), result); } #[tokio::test] async fn returns_err_when_token_cancelled_first() { let token = CancellationToken::new(); let token_clone = token.clone(); let cancel_handle = task::spawn(async move { sleep(Duration::from_millis(10)).await; token_clone.cancel(); }); let result = async { sleep(Duration::from_millis(100)).await; 7 } .or_cancel(&token) .await; cancel_handle.await.expect("cancel task panicked"); assert_eq!(Err(CancelErr::Cancelled), result); } #[tokio::test] async fn returns_err_when_token_already_cancelled() { let token = CancellationToken::new(); token.cancel(); let result = async { sleep(Duration::from_millis(50)).await; 5 } .or_cancel(&token) .await; assert_eq!(Err(CancelErr::Cancelled), result); } }