From 67736cef016686c11dbf142c309c4e5c4f0f0e36 Mon Sep 17 00:00:00 2001 From: Vicki Pfau Date: Wed, 29 May 2024 18:35:13 -0700 Subject: [PATCH] thread: Add AsyncJoinHandle for waiting on threads asynchronously --- src/lib.rs | 1 + src/thread.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 src/thread.rs diff --git a/src/lib.rs b/src/lib.rs index b0a1827..2df9e9a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ mod manager; mod process; mod sls; mod systemd; +mod thread; pub mod cec; pub mod daemon; diff --git a/src/thread.rs b/src/thread.rs new file mode 100644 index 0000000..9a11e2b --- /dev/null +++ b/src/thread.rs @@ -0,0 +1,113 @@ +/* + * Copyright © 2023 Collabora Ltd. + * Copyright © 2024 Valve Software + * + * SPDX-License-Identifier: MIT + */ + +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use std::thread::{self, JoinHandle}; + +pub(crate) struct AsyncJoinHandle +where + T: Send + 'static, +{ + join_handle: Option>, + context: Arc>, +} + +struct JoinContext { + waker: Option, + exited: bool, +} + +struct JoinGuard { + context: Arc>, +} + +impl Future for AsyncJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = Pin::into_inner(self); + let guard = this.context.lock(); + let mut context = guard.unwrap(); + context.waker.replace(cx.waker().clone()); + if let Some(join_handle) = this.join_handle.as_mut() { + if join_handle.is_finished() || context.exited { + let join_handle = this.join_handle.take().unwrap(); + return Poll::Ready(join_handle.join().unwrap()); + } + } + Poll::Pending + } +} + +impl Drop for JoinGuard { + fn drop(&mut self) { + let guard = self.context.lock(); + let mut context = guard.unwrap(); + context.exited = true; + let waker = context.waker.take(); + if let Some(waker) = waker { + waker.wake(); + } + } +} + +pub(crate) fn spawn(f: F) -> AsyncJoinHandle +where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, +{ + let context = Arc::new(Mutex::new(JoinContext { + waker: None, + exited: false, + })); + + let thread_context = context.clone(); + let join_handle = Some(thread::spawn(move || { + let _guard = JoinGuard { + context: thread_context, + }; + f() + })); + + AsyncJoinHandle { + join_handle, + context, + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::thread::sleep as sync_sleep; + use std::time::Duration; + use tokio::time::sleep as async_sleep; + + #[tokio::test] + async fn test_join() { + let handle = spawn(|| true); + assert!(handle.await); + } + + #[tokio::test] + async fn test_slow_join() { + let handle = spawn(|| true); + async_sleep(Duration::from_millis(100)).await; + assert!(handle.await); + } + + #[tokio::test] + async fn test_slow_thread() { + let handle = spawn(|| { + sync_sleep(Duration::from_millis(100)); + true + }); + assert!(handle.await); + } +}