// Path: blob/main/crates/bevy_ecs/src/schedule/graph/dag.rs
// 9368 views
use alloc::vec::Vec;1use core::{2fmt::{self, Debug},3hash::{BuildHasher, Hash},4ops::{Deref, DerefMut},5};67use bevy_platform::{8collections::{HashMap, HashSet},9hash::FixedHasher,10};11use fixedbitset::FixedBitSet;12use indexmap::IndexSet;13use thiserror::Error;1415use crate::{16error::Result,17schedule::graph::{18index, row_col, DiGraph, DiGraphToposortError,19Direction::{Incoming, Outgoing},20GraphNodeId, UnGraph,21},22};2324/// A directed acyclic graph structure.25#[derive(Clone)]26pub struct Dag<N: GraphNodeId, S: BuildHasher = FixedHasher> {27/// The underlying directed graph.28graph: DiGraph<N, S>,29/// A cached topological ordering of the graph. This is recomputed when the30/// graph is modified, and is not valid when `dirty` is true.31toposort: Vec<N>,32/// Whether the graph has been modified since the last topological sort.33dirty: bool,34}3536impl<N: GraphNodeId, S: BuildHasher> Dag<N, S> {37/// Creates a new directed acyclic graph.38pub fn new() -> Self39where40S: Default,41{42Self::default()43}4445/// Read-only access to the underlying directed graph.46#[must_use]47pub fn graph(&self) -> &DiGraph<N, S> {48&self.graph49}5051/// Mutable access to the underlying directed graph. 
Marks the graph as dirty.52#[must_use = "This function marks the graph as dirty, so it should be used."]53pub fn graph_mut(&mut self) -> &mut DiGraph<N, S> {54self.dirty = true;55&mut self.graph56}5758/// Returns whether the graph is dirty (i.e., has been modified since the59/// last topological sort).60#[must_use]61pub fn is_dirty(&self) -> bool {62self.dirty63}6465/// Returns whether the graph is topologically sorted (i.e., not dirty).66#[must_use]67pub fn is_toposorted(&self) -> bool {68!self.dirty69}7071/// Ensures the graph is topologically sorted, recomputing the toposort if72/// the graph is dirty.73///74/// # Errors75///76/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be77/// topologically sorted.78pub fn ensure_toposorted(&mut self) -> Result<(), DiGraphToposortError<N>> {79if self.dirty {80// recompute the toposort, reusing the existing allocation81self.toposort = self.graph.toposort(core::mem::take(&mut self.toposort))?;82self.dirty = false;83}84Ok(())85}8687/// Returns the cached toposort if the graph is not dirty, otherwise returns88/// `None`.89#[must_use = "This method only returns a cached value and does not compute anything."]90pub fn get_toposort(&self) -> Option<&[N]> {91if self.dirty {92None93} else {94Some(&self.toposort)95}96}9798/// Returns a topological ordering of the graph, computing it if the graph99/// is dirty.100///101/// # Errors102///103/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be104/// topologically sorted.105pub fn toposort(&mut self) -> Result<&[N], DiGraphToposortError<N>> {106self.ensure_toposorted()?;107Ok(&self.toposort)108}109110/// Returns both the topological ordering and the underlying graph,111/// computing the toposort if the graph is dirty.112///113/// This function is useful to avoid multiple borrow issues when both114/// the graph and the toposort are needed.115///116/// # Errors117///118/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be119/// 
topologically sorted.120pub fn toposort_and_graph(121&mut self,122) -> Result<(&[N], &DiGraph<N, S>), DiGraphToposortError<N>> {123self.ensure_toposorted()?;124Ok((&self.toposort, &self.graph))125}126127/// Processes a DAG and computes various properties about it.128///129/// See [`DagAnalysis::new`] for details on what is computed.130///131/// # Note132///133/// If the DAG is dirty, this method will first attempt to topologically sort it.134///135/// # Errors136///137/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be138/// topologically sorted.139///140pub fn analyze(&mut self) -> Result<DagAnalysis<N, S>, DiGraphToposortError<N>>141where142S: Default,143{144let (toposort, graph) = self.toposort_and_graph()?;145Ok(DagAnalysis::new(graph, toposort))146}147148/// Replaces the current graph with its transitive reduction based on the149/// provided analysis.150///151/// # Note152///153/// The given [`DagAnalysis`] must have been generated from this DAG.154pub fn remove_redundant_edges(&mut self, analysis: &DagAnalysis<N, S>)155where156S: Clone,157{158// We don't need to mark the graph as dirty, since transitive reduction159// is guaranteed to have the same topological ordering as the original graph.160self.graph = analysis.transitive_reduction.clone();161}162163/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`164/// under all of their ancestor key nodes. 
`num_groups` hints at the165/// expected number of groups, for memory allocation optimization.166///167/// The node type `N` must be convertible into either a key type `K` or168/// a value type `V` via the [`TryInto`] trait.169///170/// # Errors171///172/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be173/// topologically sorted.174pub fn group_by_key<K, V>(175&mut self,176num_groups: usize,177) -> Result<DagGroups<K, V, S>, DiGraphToposortError<N>>178where179N: TryInto<K, Error = V>,180K: Eq + Hash,181V: Clone + Eq + Hash,182S: BuildHasher + Default,183{184let (toposort, graph) = self.toposort_and_graph()?;185Ok(DagGroups::with_capacity(num_groups, graph, toposort))186}187188/// Converts from one [`GraphNodeId`] type to another. If the conversion fails,189/// it returns the error from the target type's [`TryFrom`] implementation.190///191/// Nodes must uniquely convert from `N` to `T` (i.e. no two `N` can convert192/// to the same `T`). The resulting DAG must be re-topologically sorted.193///194/// # Errors195///196/// If the conversion fails, it returns an error of type `N::Error`.197pub fn try_convert<T>(self) -> Result<Dag<T, S>, N::Error>198where199N: TryInto<T>,200T: GraphNodeId,201S: Default,202{203Ok(Dag {204graph: self.graph.try_convert()?,205toposort: Vec::new(),206dirty: true,207})208}209}210211impl<N: GraphNodeId, S: BuildHasher> Deref for Dag<N, S> {212type Target = DiGraph<N, S>;213214fn deref(&self) -> &Self::Target {215self.graph()216}217}218219impl<N: GraphNodeId, S: BuildHasher> DerefMut for Dag<N, S> {220fn deref_mut(&mut self) -> &mut Self::Target {221self.graph_mut()222}223}224225impl<N: GraphNodeId, S: BuildHasher + Default> Default for Dag<N, S> {226fn default() -> Self {227Self {228graph: Default::default(),229toposort: Default::default(),230dirty: false,231}232}233}234235impl<N: GraphNodeId, S: BuildHasher> Debug for Dag<N, S> {236fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {237if self.dirty 
{238f.debug_struct("Dag")239.field("graph", &self.graph)240.field("dirty", &self.dirty)241.finish()242} else {243f.debug_struct("Dag")244.field("graph", &self.graph)245.field("toposort", &self.toposort)246.finish()247}248}249}250251/// Stores the results of a call to [`Dag::analyze`].252pub struct DagAnalysis<N: GraphNodeId, S: BuildHasher = FixedHasher> {253/// Boolean reachability matrix for the graph.254reachable: FixedBitSet,255/// Pairs of nodes that have a path connecting them.256connected: HashSet<(N, N), S>,257/// Pairs of nodes that don't have a path connecting them.258disconnected: Vec<(N, N)>,259/// Edges that are redundant because a longer path exists.260transitive_edges: Vec<(N, N)>,261/// Variant of the graph with no transitive edges.262transitive_reduction: DiGraph<N, S>,263/// Variant of the graph with all possible transitive edges.264transitive_closure: DiGraph<N, S>,265}266267impl<N: GraphNodeId, S: BuildHasher> DagAnalysis<N, S> {268/// Processes a DAG and computes its:269/// - transitive reduction (along with the set of removed edges)270/// - transitive closure271/// - reachability matrix (as a bitset)272/// - pairs of nodes connected by a path273/// - pairs of nodes not connected by a path274///275/// The algorithm implemented comes from276/// ["On the calculation of transitive reduction-closure of orders"][1] by Habib, Morvan and Rampon.277///278/// [1]: https://doi.org/10.1016/0012-365X(93)90164-O279pub fn new(graph: &DiGraph<N, S>, topological_order: &[N]) -> Self280where281S: Default,282{283if graph.node_count() == 0 {284return DagAnalysis::default();285}286let n = graph.node_count();287288// build a copy of the graph where the nodes and edges appear in topsorted order289let mut map = <HashMap<_, _>>::with_capacity_and_hasher(n, Default::default());290let mut topsorted =291DiGraph::<N>::with_capacity(topological_order.len(), graph.edge_count());292293// iterate nodes in topological order294for (i, &node) in 
topological_order.iter().enumerate() {295map.insert(node, i);296topsorted.add_node(node);297// insert nodes as successors to their predecessors298for pred in graph.neighbors_directed(node, Incoming) {299topsorted.add_edge(pred, node);300}301}302303let mut reachable = FixedBitSet::with_capacity(n * n);304let mut connected = HashSet::default();305let mut disconnected = Vec::default();306let mut transitive_edges = Vec::default();307let mut transitive_reduction = DiGraph::with_capacity(topsorted.node_count(), 0);308let mut transitive_closure = DiGraph::with_capacity(topsorted.node_count(), 0);309310let mut visited = FixedBitSet::with_capacity(n);311312// iterate nodes in topological order313for node in topsorted.nodes() {314transitive_reduction.add_node(node);315transitive_closure.add_node(node);316}317318// iterate nodes in reverse topological order319for a in topsorted.nodes().rev() {320let index_a = *map.get(&a).unwrap();321// iterate their successors in topological order322for b in topsorted.neighbors_directed(a, Outgoing) {323let index_b = *map.get(&b).unwrap();324debug_assert!(index_a < index_b);325if !visited[index_b] {326// edge <a, b> is not redundant327transitive_reduction.add_edge(a, b);328transitive_closure.add_edge(a, b);329reachable.insert(index(index_a, index_b, n));330331let successors = transitive_closure332.neighbors_directed(b, Outgoing)333.collect::<Vec<_>>();334for c in successors {335let index_c = *map.get(&c).unwrap();336debug_assert!(index_b < index_c);337if !visited[index_c] {338visited.insert(index_c);339transitive_closure.add_edge(a, c);340reachable.insert(index(index_a, index_c, n));341}342}343} else {344// edge <a, b> is redundant345transitive_edges.push((a, b));346}347}348349visited.clear();350}351352// partition pairs of nodes into "connected by path" and "not connected by path"353for i in 0..(n - 1) {354// reachable is upper triangular because the nodes were topsorted355for index in index(i, i + 1, n)..=index(i, n - 1, n) {356let (a, b) 
= row_col(index, n);357let pair = (topological_order[a], topological_order[b]);358if reachable[index] {359connected.insert(pair);360} else {361disconnected.push(pair);362}363}364}365366// fill diagonal (nodes reach themselves)367// for i in 0..n {368// reachable.set(index(i, i, n), true);369// }370371DagAnalysis {372reachable,373connected,374disconnected,375transitive_edges,376transitive_reduction,377transitive_closure,378}379}380381/// Returns the reachability matrix.382pub fn reachable(&self) -> &FixedBitSet {383&self.reachable384}385386/// Returns the set of node pairs that are connected by a path.387pub fn connected(&self) -> &HashSet<(N, N), S> {388&self.connected389}390391/// Returns the list of node pairs that are not connected by a path.392pub fn disconnected(&self) -> &[(N, N)] {393&self.disconnected394}395396/// Returns the list of redundant edges because a longer path exists.397pub fn transitive_edges(&self) -> &[(N, N)] {398&self.transitive_edges399}400401/// Returns the transitive reduction of the graph.402pub fn transitive_reduction(&self) -> &DiGraph<N, S> {403&self.transitive_reduction404}405406/// Returns the transitive closure of the graph.407pub fn transitive_closure(&self) -> &DiGraph<N, S> {408&self.transitive_closure409}410411/// Checks if the graph has any redundant (transitive) edges.412///413/// # Errors414///415/// If there are redundant edges, returns a [`DagRedundancyError`]416/// containing the list of redundant edges.417pub fn check_for_redundant_edges(&self) -> Result<(), DagRedundancyError<N>>418where419S: Clone,420{421if self.transitive_edges.is_empty() {422Ok(())423} else {424Err(DagRedundancyError(self.transitive_edges.clone()))425}426}427428/// Checks if there are any pairs of nodes that have a path in both this429/// graph and another graph.430///431/// # Errors432///433/// Returns [`DagCrossDependencyError`] if any node pair is connected in434/// both graphs.435pub fn check_for_cross_dependencies(436&self,437other: &Self,438) 
-> Result<(), DagCrossDependencyError<N>> {439for &(a, b) in &self.connected {440if other.connected.contains(&(a, b)) || other.connected.contains(&(b, a)) {441return Err(DagCrossDependencyError(a, b));442}443}444445Ok(())446}447448/// Checks if any connected node pairs that are both keys have overlapping449/// groups.450///451/// # Errors452///453/// If there are overlapping groups, returns a [`DagOverlappingGroupError`]454/// containing the first pair of keys that have overlapping groups.455pub fn check_for_overlapping_groups<K, V>(456&self,457groups: &DagGroups<K, V>,458) -> Result<(), DagOverlappingGroupError<K>>459where460N: TryInto<K>,461K: Eq + Hash,462V: Eq + Hash,463{464for &(a, b) in &self.connected {465let (Ok(a_key), Ok(b_key)) = (a.try_into(), b.try_into()) else {466continue;467};468let a_group = groups.get(&a_key).unwrap();469let b_group = groups.get(&b_key).unwrap();470if !a_group.is_disjoint(b_group) {471return Err(DagOverlappingGroupError(a_key, b_key));472}473}474Ok(())475}476}477478impl<N: GraphNodeId, S: BuildHasher + Default> Default for DagAnalysis<N, S> {479fn default() -> Self {480Self {481reachable: Default::default(),482connected: Default::default(),483disconnected: Default::default(),484transitive_edges: Default::default(),485transitive_reduction: Default::default(),486transitive_closure: Default::default(),487}488}489}490491impl<N: GraphNodeId, S: BuildHasher> Debug for DagAnalysis<N, S> {492fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {493f.debug_struct("DagAnalysis")494.field("reachable", &self.reachable)495.field("connected", &self.connected)496.field("disconnected", &self.disconnected)497.field("transitive_edges", &self.transitive_edges)498.field("transitive_reduction", &self.transitive_reduction)499.field("transitive_closure", &self.transitive_closure)500.finish()501}502}503504/// A mapping of keys to groups of values in a [`Dag`].505pub struct DagGroups<K, V, S = FixedHasher>(HashMap<K, IndexSet<V, S>, S>);506507impl<K: 
Eq + Hash, V: Clone + Eq + Hash, S: BuildHasher + Default> DagGroups<K, V, S> {508/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`509/// under all of their ancestor key nodes.510///511/// The node type `N` must be convertible into either a key type `K` or512/// a value type `V` via the [`TryInto`] trait.513pub fn new<N>(graph: &DiGraph<N, S>, toposort: &[N]) -> Self514where515N: GraphNodeId + TryInto<K, Error = V>,516{517Self::with_capacity(0, graph, toposort)518}519520/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`521/// under all of their ancestor key nodes. `capacity` hints at the522/// expected number of groups, for memory allocation optimization.523///524/// The node type `N` must be convertible into either a key type `K` or525/// a value type `V` via the [`TryInto`] trait.526pub fn with_capacity<N>(capacity: usize, graph: &DiGraph<N, S>, toposort: &[N]) -> Self527where528N: GraphNodeId + TryInto<K, Error = V>,529{530let mut groups: HashMap<K, IndexSet<V, S>, S> =531HashMap::with_capacity_and_hasher(capacity, Default::default());532533// Iterate in reverse topological order (bottom-up) so we hit children before parents.534for &id in toposort.iter().rev() {535let Ok(key) = id.try_into() else {536continue;537};538539let mut children = IndexSet::default();540541for node in graph.neighbors_directed(id, Outgoing) {542match node.try_into() {543Ok(key) => {544// If the child is a key, this key inherits all of its children.545let key_children = groups.get(&key).unwrap();546children.extend(key_children.iter().cloned());547}548Err(value) => {549// If the child is a value, add it directly.550children.insert(value);551}552}553}554555groups.insert(key, children);556}557558Self(groups)559}560}561562impl<K: GraphNodeId, V: GraphNodeId, S: BuildHasher> DagGroups<K, V, S> {563/// Converts the given [`Dag`] into a flattened version where key nodes564/// (`K`) are replaced by their associated value nodes (`V`). 
Edges to/from565/// key nodes are redirected to connect their value nodes instead.566///567/// The `collapse_group` function is called for each key node to customize568/// how its group is collapsed.569///570/// The resulting [`Dag`] will have only value nodes (`V`).571pub fn flatten<N>(572&self,573dag: Dag<N>,574mut collapse_group: impl FnMut(K, &IndexSet<V, S>, &Dag<N>, &mut Vec<(N, N)>),575) -> Dag<V>576where577N: GraphNodeId + TryInto<V, Error = K> + From<K> + From<V>,578{579let mut flattening = dag;580let mut temp = Vec::new();581582for (&key, values) in self.iter() {583// Call the user-provided function to handle collapsing the group.584collapse_group(key, values, &flattening, &mut temp);585586if values.is_empty() {587// Replace connections to the key node with connections between its neighbors.588for a in flattening.neighbors_directed(N::from(key), Incoming) {589for b in flattening.neighbors_directed(N::from(key), Outgoing) {590temp.push((a, b));591}592}593} else {594// Redirect edges to/from the key node to connect to its value nodes.595for a in flattening.neighbors_directed(N::from(key), Incoming) {596for &value in values {597temp.push((a, N::from(value)));598}599}600for b in flattening.neighbors_directed(N::from(key), Outgoing) {601for &value in values {602temp.push((N::from(value), b));603}604}605}606607// Remove the key node from the graph.608flattening.remove_node(N::from(key));609// Add all previously collected edges.610flattening.reserve_edges(temp.len());611for (a, b) in temp.drain(..) {612flattening.add_edge(a, b);613}614}615616// By this point, we should have removed all keys from the graph,617// so this conversion should never fail.618flattening619.try_convert::<V>()620.unwrap_or_else(|n| unreachable!("Flattened graph has a leftover key {n:?}"))621}622623/// Converts an undirected graph by replacing key nodes (`K`) with their624/// associated value nodes (`V`). 
Edges connected to key nodes are625/// redirected to connect their value nodes instead.626///627/// The resulting undirected graph will have only value nodes (`V`).628pub fn flatten_undirected<N>(&self, graph: &UnGraph<N>) -> UnGraph<V>629where630N: GraphNodeId + TryInto<V, Error = K>,631{632let mut flattened = UnGraph::default();633634for (lhs, rhs) in graph.all_edges() {635match (lhs.try_into(), rhs.try_into()) {636(Ok(lhs), Ok(rhs)) => {637// Normal edge between two value nodes638flattened.add_edge(lhs, rhs);639}640(Err(lhs_key), Ok(rhs)) => {641// Edge from a key node to a value node, expand to all values in the key's group642let Some(lhs_group) = self.get(&lhs_key) else {643continue;644};645flattened.reserve_edges(lhs_group.len());646for &lhs in lhs_group {647flattened.add_edge(lhs, rhs);648}649}650(Ok(lhs), Err(rhs_key)) => {651// Edge from a value node to a key node, expand to all values in the key's group652let Some(rhs_group) = self.get(&rhs_key) else {653continue;654};655flattened.reserve_edges(rhs_group.len());656for &rhs in rhs_group {657flattened.add_edge(lhs, rhs);658}659}660(Err(lhs_key), Err(rhs_key)) => {661// Edge between two key nodes, expand to all combinations of their value nodes662let Some(lhs_group) = self.get(&lhs_key) else {663continue;664};665let Some(rhs_group) = self.get(&rhs_key) else {666continue;667};668flattened.reserve_edges(lhs_group.len() * rhs_group.len());669for &lhs in lhs_group {670for &rhs in rhs_group {671flattened.add_edge(lhs, rhs);672}673}674}675}676}677678flattened679}680}681682impl<K, V, S> Deref for DagGroups<K, V, S> {683type Target = HashMap<K, IndexSet<V, S>, S>;684685fn deref(&self) -> &Self::Target {686&self.0687}688}689690impl<K, V, S> DerefMut for DagGroups<K, V, S> {691fn deref_mut(&mut self) -> &mut Self::Target {692&mut self.0693}694}695696impl<K, V, S> Default for DagGroups<K, V, S>697where698S: BuildHasher + Default,699{700fn default() -> Self {701Self(Default::default())702}703}704705impl<K: Debug, V: 
Debug, S> Debug for DagGroups<K, V, S> {706fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {707f.debug_tuple("DagGroups").field(&self.0).finish()708}709}710711/// Error indicating that the graph has redundant edges.712#[derive(Error, Debug)]713#[error("DAG has redundant edges: {0:?}")]714pub struct DagRedundancyError<N: GraphNodeId>(pub Vec<(N, N)>);715716/// Error indicating that two graphs both have a dependency between the same nodes.717#[derive(Error, Debug)]718#[error("DAG has a cross-dependency between nodes {0:?} and {1:?}")]719pub struct DagCrossDependencyError<N>(pub N, pub N);720721/// Error indicating that the graph has overlapping groups between two keys.722#[derive(Error, Debug)]723#[error("DAG has overlapping groups between keys {0:?} and {1:?}")]724pub struct DagOverlappingGroupError<K>(pub K, pub K);725726#[cfg(test)]727mod tests {728use core::ops::DerefMut;729730use crate::schedule::graph::{index, Dag, Direction, GraphNodeId, UnGraph};731732#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]733struct TestNode(u32);734735impl GraphNodeId for TestNode {736type Adjacent = (TestNode, Direction);737type Edge = (TestNode, TestNode);738739fn kind(&self) -> &'static str {740"test node"741}742}743744#[test]745fn mark_dirty() {746{747let mut dag = Dag::<TestNode>::new();748dag.add_node(TestNode(1));749assert!(dag.is_dirty());750}751{752let mut dag = Dag::<TestNode>::new();753dag.add_edge(TestNode(1), TestNode(2));754assert!(dag.is_dirty());755}756{757let mut dag = Dag::<TestNode>::new();758dag.deref_mut();759assert!(dag.is_dirty());760}761{762let mut dag = Dag::<TestNode>::new();763let _ = dag.graph_mut();764assert!(dag.is_dirty());765}766}767768#[test]769fn toposort() {770let mut dag = Dag::<TestNode>::new();771dag.add_edge(TestNode(1), TestNode(2));772dag.add_edge(TestNode(2), TestNode(3));773dag.add_edge(TestNode(1), TestNode(3));774775assert_eq!(776dag.toposort().unwrap(),777&[TestNode(1), TestNode(2), 
TestNode(3)]778);779assert_eq!(780dag.get_toposort().unwrap(),781&[TestNode(1), TestNode(2), TestNode(3)]782);783}784785#[test]786fn analyze() {787let mut dag1 = Dag::<TestNode>::new();788dag1.add_edge(TestNode(1), TestNode(2));789dag1.add_edge(TestNode(2), TestNode(3));790dag1.add_edge(TestNode(1), TestNode(3)); // redundant edge791792let analysis1 = dag1.analyze().unwrap();793794assert!(analysis1.reachable().contains(index(0, 1, 3)));795assert!(analysis1.reachable().contains(index(1, 2, 3)));796assert!(analysis1.reachable().contains(index(0, 2, 3)));797798assert!(analysis1.connected().contains(&(TestNode(1), TestNode(2))));799assert!(analysis1.connected().contains(&(TestNode(2), TestNode(3))));800assert!(analysis1.connected().contains(&(TestNode(1), TestNode(3))));801802assert!(!analysis1803.disconnected()804.contains(&(TestNode(2), TestNode(1))));805assert!(!analysis1806.disconnected()807.contains(&(TestNode(3), TestNode(2))));808assert!(!analysis1809.disconnected()810.contains(&(TestNode(3), TestNode(1))));811812assert!(analysis1813.transitive_edges()814.contains(&(TestNode(1), TestNode(3))));815816assert!(analysis1.check_for_redundant_edges().is_err());817818let mut dag2 = Dag::<TestNode>::new();819dag2.add_edge(TestNode(3), TestNode(4));820821let analysis2 = dag2.analyze().unwrap();822823assert!(analysis2.check_for_redundant_edges().is_ok());824assert!(analysis1.check_for_cross_dependencies(&analysis2).is_ok());825826let mut dag3 = Dag::<TestNode>::new();827dag3.add_edge(TestNode(1), TestNode(2));828829let analysis3 = dag3.analyze().unwrap();830831assert!(analysis1.check_for_cross_dependencies(&analysis3).is_err());832833dag1.remove_redundant_edges(&analysis1);834let analysis1 = dag1.analyze().unwrap();835assert!(analysis1.check_for_redundant_edges().is_ok());836}837838#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]839enum Node {840Key(Key),841Value(Value),842}843#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]844struct 
Key(u32);845#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]846struct Value(u32);847848impl GraphNodeId for Node {849type Adjacent = (Node, Direction);850type Edge = (Node, Node);851852fn kind(&self) -> &'static str {853"node"854}855}856857impl TryInto<Key> for Node {858type Error = Value;859860fn try_into(self) -> Result<Key, Value> {861match self {862Node::Key(k) => Ok(k),863Node::Value(v) => Err(v),864}865}866}867868impl TryInto<Value> for Node {869type Error = Key;870871fn try_into(self) -> Result<Value, Key> {872match self {873Node::Value(v) => Ok(v),874Node::Key(k) => Err(k),875}876}877}878879impl GraphNodeId for Key {880type Adjacent = (Key, Direction);881type Edge = (Key, Key);882883fn kind(&self) -> &'static str {884"key"885}886}887888impl GraphNodeId for Value {889type Adjacent = (Value, Direction);890type Edge = (Value, Value);891892fn kind(&self) -> &'static str {893"value"894}895}896897impl From<Key> for Node {898fn from(key: Key) -> Self {899Node::Key(key)900}901}902903impl From<Value> for Node {904fn from(value: Value) -> Self {905Node::Value(value)906}907}908909#[test]910fn group_by_key() {911let mut dag = Dag::<Node>::new();912dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));913dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));914dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));915dag.add_edge(Node::Key(Key(2)), Node::Key(Key(1)));916dag.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));917918let groups = dag.group_by_key::<Key, Value>(2).unwrap();919assert_eq!(groups.len(), 2);920921let group_key1 = groups.get(&Key(1)).unwrap();922assert!(group_key1.contains(&Value(10)));923assert!(group_key1.contains(&Value(11)));924925let group_key2 = groups.get(&Key(2)).unwrap();926assert!(group_key2.contains(&Value(10)));927assert!(group_key2.contains(&Value(11)));928assert!(group_key2.contains(&Value(20)));929}930931#[test]932fn flatten() {933let mut dag = Dag::<Node>::new();934dag.add_edge(Node::Key(Key(1)), 
Node::Value(Value(10)));935dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));936dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));937dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));938dag.add_edge(Node::Value(Value(30)), Node::Key(Key(1)));939dag.add_edge(Node::Key(Key(1)), Node::Value(Value(40)));940941let groups = dag.group_by_key::<Key, Value>(2).unwrap();942let flattened = groups.flatten(dag, |_key, _values, _dag, _temp| {});943944assert!(flattened.contains_node(Value(10)));945assert!(flattened.contains_node(Value(11)));946assert!(flattened.contains_node(Value(20)));947assert!(flattened.contains_node(Value(21)));948assert!(flattened.contains_node(Value(30)));949assert!(flattened.contains_node(Value(40)));950951assert!(flattened.contains_edge(Value(30), Value(10)));952assert!(flattened.contains_edge(Value(30), Value(11)));953assert!(flattened.contains_edge(Value(10), Value(40)));954assert!(flattened.contains_edge(Value(11), Value(40)));955}956957#[test]958fn flatten_undirected() {959let mut dag = Dag::<Node>::new();960dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));961dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));962dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));963dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));964965let groups = dag.group_by_key::<Key, Value>(2).unwrap();966967let mut ungraph = UnGraph::<Node>::default();968ungraph.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));969ungraph.add_edge(Node::Key(Key(1)), Node::Value(Value(30)));970ungraph.add_edge(Node::Value(Value(40)), Node::Key(Key(2)));971ungraph.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));972973let flattened = groups.flatten_undirected(&ungraph);974975assert!(flattened.contains_edge(Value(10), Value(11)));976assert!(flattened.contains_edge(Value(10), Value(30)));977assert!(flattened.contains_edge(Value(11), Value(30)));978assert!(flattened.contains_edge(Value(40), Value(20)));979assert!(flattened.contains_edge(Value(40), 
Value(21)));980assert!(flattened.contains_edge(Value(10), Value(20)));981assert!(flattened.contains_edge(Value(10), Value(21)));982assert!(flattened.contains_edge(Value(11), Value(20)));983assert!(flattened.contains_edge(Value(11), Value(21)));984}985986#[test]987fn overlapping_groups() {988let mut dag = Dag::<Node>::new();989dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));990dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));991dag.add_edge(Node::Key(Key(2)), Node::Value(Value(11))); // overlap992dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));993dag.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));994995let groups = dag.group_by_key::<Key, Value>(2).unwrap();996let analysis = dag.analyze().unwrap();997998let result = analysis.check_for_overlapping_groups(&groups);999assert!(result.is_err());1000}10011002#[test]1003fn disjoint_groups() {1004let mut dag = Dag::<Node>::new();1005dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));1006dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));1007dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));1008dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));10091010let groups = dag.group_by_key::<Key, Value>(2).unwrap();1011let analysis = dag.analyze().unwrap();10121013let result = analysis.check_for_overlapping_groups(&groups);1014assert!(result.is_ok());1015}1016}101710181019