diff --git a/src/daemon.rs b/src/daemon.rs new file mode 100644 index 0000000..54a5e72 --- /dev/null +++ b/src/daemon.rs @@ -0,0 +1,87 @@ +/* + * Copyright © 2023 Collabora Ltd. + * Copyright © 2024 Valve Software + * + * SPDX-License-Identifier: MIT + */ + +use anyhow::{anyhow, Result}; +use tokio::signal::unix::{signal, Signal, SignalKind}; +use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; +use tracing::{error, info}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::registry::LookupSpan; +use zbus::connection::Connection; + +use crate::sls::{LogLayer, LogReceiver}; +use crate::{reload, Service}; + +pub struct Daemon { + services: JoinSet>, + token: CancellationToken, + sigterm: Signal, + sigquit: Signal, +} + +impl Daemon { + pub async fn new LookupSpan<'a>>( + subscriber: S, + connection: Connection, + ) -> Result { + let services = JoinSet::new(); + let token = CancellationToken::new(); + + let log_receiver = LogReceiver::new(connection.clone()).await?; + let remote_logger = LogLayer::new(&log_receiver).await; + let subscriber = subscriber.with(remote_logger); + tracing::subscriber::set_global_default(subscriber)?; + + let sigterm = signal(SignalKind::terminate())?; + let sigquit = signal(SignalKind::quit())?; + + let mut daemon = Daemon { + services, + token, + sigterm, + sigquit, + }; + daemon.add_service(log_receiver); + + Ok(daemon) + } + + pub fn add_service(&mut self, service: S) { + let token = self.token.clone(); + self.services + .spawn(async move { service.start(token).await }); + } + + pub async fn run(&mut self) -> Result<()> { + let mut res = tokio::select! { + e = self.services.join_next() => match e.unwrap() { + Ok(Ok(())) => Ok(()), + Ok(Err(e)) => Err(e), + Err(e) => Err(e.into()) + }, + _ = tokio::signal::ctrl_c() => Ok(()), + e = self.sigterm.recv() => e.ok_or(anyhow!("SIGTERM machine broke")), + _ = self.sigquit.recv() => Err(anyhow!("Got SIGQUIT")), + e = reload() => e, + } + .inspect_err(|e| error!("Encountered error running: {e}")); + self.token.cancel(); + + info!("Shutting down"); + + while let Some(service_res) = self.services.join_next().await { + res = match service_res { + Ok(Err(e)) => Err(e), + Err(e) => Err(e.into()), + _ => continue, + }; + } + + res.inspect_err(|e| error!("Encountered error: {e}")) + } +} diff --git a/src/main.rs b/src/main.rs index 5e489f3..06e2816 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use anyhow::{anyhow, Result}; use clap::Parser; +use std::future::Future; use std::path::{Path, PathBuf}; use tokio::fs::File; use tokio::io::AsyncWriteExt; @@ -14,6 +15,7 @@ use tokio::signal::unix::{signal, SignalKind}; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +mod daemon; mod ds_inhibit; mod hardware; mod manager; @@ -29,32 +31,34 @@ mod testing; trait Service where - Self: Sized, + Self: Sized + Send, { const NAME: &'static str; - async fn run(&mut self) -> Result<()>; + fn run(&mut self) -> impl Future> + Send; - async fn shutdown(&mut self) -> Result<()> { - Ok(()) + fn shutdown(&mut self) -> impl Future> + Send { + async { Ok(()) } } - async fn start(mut self, token: CancellationToken) -> Result<()> { - info!("Starting {}", Self::NAME); - let res = tokio::select! { - r = self.run() => r, - _ = token.cancelled() => Ok(()), - }; - if res.is_err() { - warn!( - "{} encountered an error: {}", - Self::NAME, - res.as_ref().unwrap_err() - ); - token.cancel(); + fn start(mut self, token: CancellationToken) -> impl Future> + Send { + async move { + info!("Starting {}", Self::NAME); + let res = tokio::select! { + r = self.run() => r, + _ = token.cancelled() => Ok(()), + }; + if res.is_err() { + warn!( + "{} encountered an error: {}", + Self::NAME, + res.as_ref().unwrap_err() + ); + token.cancel(); + } + info!("Shutting down {}", Self::NAME); + self.shutdown().await.and(res) } - info!("Shutting down {}", Self::NAME); - self.shutdown().await.and(res) } } @@ -62,7 +66,7 @@ where struct Args { /// Run the root manager daemon #[arg(short, long)] - root: bool + root: bool, } #[cfg(not(test))] diff --git a/src/root.rs b/src/root.rs index 79bc06c..525a8be 100644 --- a/src/root.rs +++ b/src/root.rs @@ -5,20 +5,17 @@ * SPDX-License-Identifier: MIT */ -use anyhow::{anyhow, bail, Result}; -use tokio::signal::unix::{signal, SignalKind}; -use tokio::task::JoinSet; -use tokio_util::sync::CancellationToken; -use tracing::{error, info}; +use anyhow::{bail, Result}; +use tracing::error; use tracing_subscriber::prelude::*; use tracing_subscriber::{fmt, Registry}; use zbus::connection::Connection; use zbus::ConnectionBuilder; +use crate::daemon::Daemon; use crate::ds_inhibit::Inhibitor; -use crate::{manager, reload, Service}; +use crate::manager; use crate::sls::ftrace::Ftrace; -use crate::sls::{LogLayer, LogReceiver}; async fn create_connection() -> Result { let connection = ConnectionBuilder::system()? @@ -48,48 +45,13 @@ pub async fn daemon() -> Result<()> { bail!(e); } }; - - let mut services = JoinSet::new(); - let token = CancellationToken::new(); - - let mut log_receiver = LogReceiver::new(connection.clone()).await?; - let remote_logger = LogLayer::new(&log_receiver).await; - let subscriber = subscriber.with(remote_logger); - tracing::subscriber::set_global_default(subscriber)?; - - let mut sigterm = signal(SignalKind::terminate())?; - let mut sigquit = signal(SignalKind::quit())?; + let mut daemon = Daemon::new(subscriber, connection.clone()).await?; let ftrace = Ftrace::init(connection.clone()).await?; - services.spawn(ftrace.start(token.clone())); + daemon.add_service(ftrace); let inhibitor = Inhibitor::init().await?; - services.spawn(inhibitor.start(token.clone())); + daemon.add_service(inhibitor); - let mut res = tokio::select! { - e = log_receiver.run() => e, - e = services.join_next() => match e.unwrap() { - Ok(Ok(())) => Ok(()), - Ok(Err(e)) => Err(e), - Err(e) => Err(e.into()) - }, - _ = tokio::signal::ctrl_c() => Ok(()), - e = sigterm.recv() => e.ok_or(anyhow!("SIGTERM machine broke")), - _ = sigquit.recv() => Err(anyhow!("Got SIGQUIT")), - e = reload() => e, - } - .inspect_err(|e| error!("Encountered error running: {e}")); - token.cancel(); - - info!("Shutting down"); - - while let Some(service_res) = services.join_next().await { - res = match service_res { - Ok(Err(e)) => Err(e), - Err(e) => Err(e.into()), - _ => continue, - }; - } - - res.inspect_err(|e| error!("Encountered error: {e}")) + daemon.run().await }