use anyhow::{anyhow, bail, Result}; use libc::pid_t; use nix::sys::signal; use nix::sys::signal::Signal; use nix::unistd::Pid; use std::cell::{Cell, RefCell}; use std::collections::{HashMap, HashSet}; use std::ffi::OsStr; use std::iter::zip; use std::path::Path; use std::process::Stdio; use std::rc::Rc; use std::str::FromStr; use std::time::Duration; use tempfile::{tempdir, TempDir}; use tokio::fs::read; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::{Child, Command}; use tokio::sync::Mutex; use tracing::error; use zbus::zvariant::ObjectPath; use zbus::{Address, Connection, ConnectionBuilder, Interface}; use zbus_xml::{Method, Node, Property}; use crate::platform::PlatformConfig; thread_local! { static TEST: RefCell>> = const { RefCell::new(None) }; } #[macro_export] macro_rules! enum_roundtrip { ($enum:ident => $value:literal : str = $variant:ident) => { assert_eq!($enum::$variant.to_string(), $value); assert_eq!($enum::from_str($value).unwrap(), $enum::$variant); }; ($enum:ident => $value:literal : $ty:ty = $variant:ident) => { assert_eq!($enum::$variant as $ty, $value); assert_eq!($enum::try_from($value).unwrap(), $enum::$variant); }; ($enum:ident { $($value:literal : $ty:ident = $variant:ident,)+ }) => { $(enum_roundtrip!($enum => $value : $ty = $variant);)+ }; } #[macro_export] macro_rules! enum_on_off { ($enum:ident => ($on:ident, $off:ident)) => { assert_eq!($enum::from_str("on").unwrap(), $enum::$on); assert_eq!($enum::from_str("On").unwrap(), $enum::$on); assert_eq!($enum::from_str("enable").unwrap(), $enum::$on); assert_eq!($enum::from_str("enabled").unwrap(), $enum::$on); assert_eq!($enum::from_str("1").unwrap(), $enum::$on); assert_eq!($enum::from_str("off").unwrap(), $enum::$off); assert_eq!($enum::from_str("Off").unwrap(), $enum::$off); assert_eq!($enum::from_str("disable").unwrap(), $enum::$off); assert_eq!($enum::from_str("disabled").unwrap(), $enum::$off); assert_eq!($enum::from_str("0").unwrap(), $enum::$off); }; } pub fn start() -> TestHandle { TEST.with(|lock| { assert!(lock.borrow().as_ref().is_none()); let test: Rc = Rc::new(Test { base: tempdir().expect("Couldn't create test directory"), process_cb: Cell::new(|_, _| Err(anyhow!("No current process_cb"))), mock_dbus: Cell::new(None), dbus_address: Mutex::new(None), platform_config: RefCell::new(None), }); *lock.borrow_mut() = Some(test.clone()); TestHandle { test } }) } pub fn stop() { TEST.with(|lock| { let test = (*lock.borrow_mut()).take(); if let Some(test) = test { if let Some(mock_dbus) = test.mock_dbus.take() { let _ = mock_dbus.shutdown(); } } }); } pub fn current() -> Rc { TEST.with(|lock| lock.borrow().as_ref().unwrap().clone()) } pub struct MockDBus { pub connection: Connection, address: Address, process: Child, } pub struct Test { base: TempDir, pub process_cb: Cell Result<(i32, String)>>, pub mock_dbus: Cell>, pub dbus_address: Mutex>, pub platform_config: RefCell>, } pub struct TestHandle { pub test: Rc, } impl MockDBus { pub async fn new() -> Result { let mut process = Command::new("/usr/bin/dbus-daemon") .args(["--session", "--nofork", "--print-address"]) .stdout(Stdio::piped()) .spawn()?; let stdout = BufReader::new( process .stdout .take() .ok_or(anyhow!("Couldn't capture stdout"))?, ); let address = stdout .lines() .next_line() .await? .ok_or(anyhow!("Failed to read address"))?; let address = Address::from_str(address.trim_end())?; let connection = ConnectionBuilder::address(address.clone())?.build().await?; Ok(MockDBus { connection, address, process, }) } pub fn shutdown(mut self) -> Result<()> { let pid = match self.process.id() { Some(id) => id, None => return Ok(()), }; let pid: pid_t = match pid.try_into() { Ok(pid) => pid, Err(message) => bail!("Unable to get pid_t from command {message}"), }; signal::kill(Pid::from_raw(pid), Signal::SIGINT)?; for _ in [0..10] { // Wait for the process to exit synchronously, but not for too long if self.process.try_wait()?.is_some() { break; } std::thread::sleep(Duration::from_micros(100)); } Ok(()) } } impl Test { pub fn path(&self) -> &Path { self.base.path() } } impl TestHandle { pub async fn new_dbus(&mut self) -> Result { let dbus = MockDBus::new().await?; let connection = dbus.connection.clone(); *self.test.dbus_address.lock().await = Some(dbus.address.clone()); self.test.mock_dbus.set(Some(dbus)); Ok(connection) } pub async fn dbus_address(&self) -> Option
{ (*self.test.dbus_address.lock().await).clone() } } impl Drop for TestHandle { fn drop(&mut self) { stop(); } } pub struct InterfaceIntrospection<'a> { interface: zbus_xml::Interface<'a>, } impl<'a> InterfaceIntrospection<'a> { pub async fn from_remote<'p, I, P>(connection: &Connection, path: P) -> Result where I: Interface, P: TryInto>, P::Error: Into, { let iface_ref = connection.object_server().interface::<_, I>(path).await?; let iface = iface_ref.get().await; let mut remote_interface_string = String::from( "", ); iface.introspect_to_writer(&mut remote_interface_string, 0); remote_interface_string.push_str(""); Self::from_xml(remote_interface_string.as_bytes(), I::name().to_string()) } pub async fn from_local<'p, P: AsRef, S: AsRef>( path: P, interface: S, ) -> Result { let local_interface_string = read(path.as_ref()).await?; Self::from_xml(local_interface_string.as_ref(), interface) } fn from_xml>(xml: &[u8], iface_name: S) -> Result { let node = Node::from_reader(xml)?; let interfaces = node.interfaces(); let mut interface = None; for iface in interfaces { if iface.name() == iface_name.as_ref() { interface = Some(iface.clone()); break; } } Ok(if let Some(interface) = interface { InterfaceIntrospection { interface } } else { bail!("No interface found"); }) } fn collect_methods(&self) -> HashMap> { let mut map = HashMap::new(); for method in self.interface.methods() { map.insert(method.name().to_string(), method); } map } fn collect_properties(&self) -> HashMap> { let mut map = HashMap::new(); for prop in self.interface.properties() { map.insert(prop.name().to_string(), prop); } map } fn compare_methods(&self, other: &InterfaceIntrospection<'_>) -> u32 { let local_methods = self.collect_methods(); let local_method_names: HashSet<&String> = local_methods.keys().collect(); let other_methods = other.collect_methods(); let other_method_names: HashSet<&String> = other_methods.keys().collect(); let mut issues = 0; for key in local_method_names.union(&other_method_names) { let Some(local_method) = local_methods.get(*key) else { error!("Method {key} missing on self"); issues += 1; continue; }; let Some(other_method) = other_methods.get(*key) else { error!("Method {key} missing on other"); issues += 1; continue; }; if local_method.args().len() != other_method.args().len() { error!("Different arguments between {local_method:?} and {other_method:?}"); issues += 1; continue; } for (local_arg, other_arg) in zip(local_method.args().iter(), other_method.args().iter()) { if local_arg.direction() != other_arg.direction() { error!("Arguments {local_arg:?} and {other_arg:?} differ in direction"); issues += 1; continue; } if local_arg.ty() != other_arg.ty() { error!("Arguments {local_arg:?} and {other_arg:?} differ in type"); issues += 1; continue; } } } issues } fn compare_properties(&self, other: &InterfaceIntrospection<'_>) -> u32 { let local_properties = self.collect_properties(); let local_property_names: HashSet<&String> = local_properties.keys().collect(); let other_properties = other.collect_properties(); let other_property_names: HashSet<&String> = other_properties.keys().collect(); let mut issues = 0; for key in local_property_names.union(&other_property_names) { let Some(local_property) = local_properties.get(*key) else { error!("Property {key} missing on self"); issues += 1; continue; }; let Some(other_property) = other_properties.get(*key) else { error!("Property {key} missing on other"); issues += 1; continue; }; if local_property.ty() != other_property.ty() { error!("Properties {local_property:?} and {other_property:?} differ in type"); issues += 1; continue; } if local_property.access() != other_property.access() { error!("Properties {local_property:?} and {other_property:?} differ in access"); issues += 1; continue; } } issues } pub fn compare(&self, other: &InterfaceIntrospection<'_>) -> bool { let mut issues = 0; issues += self.compare_methods(other); issues += self.compare_properties(other); issues == 0 } }