Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bevyengine
GitHub Repository: bevyengine/bevy
Path: blob/main/crates/bevy_ecs/src/schedule/graph/dag.rs
9368 views
1
use alloc::vec::Vec;
2
use core::{
3
fmt::{self, Debug},
4
hash::{BuildHasher, Hash},
5
ops::{Deref, DerefMut},
6
};
7
8
use bevy_platform::{
9
collections::{HashMap, HashSet},
10
hash::FixedHasher,
11
};
12
use fixedbitset::FixedBitSet;
13
use indexmap::IndexSet;
14
use thiserror::Error;
15
16
use crate::{
17
error::Result,
18
schedule::graph::{
19
index, row_col, DiGraph, DiGraphToposortError,
20
Direction::{Incoming, Outgoing},
21
GraphNodeId, UnGraph,
22
},
23
};
24
25
/// A directed acyclic graph structure.
26
#[derive(Clone)]
27
pub struct Dag<N: GraphNodeId, S: BuildHasher = FixedHasher> {
28
/// The underlying directed graph.
29
graph: DiGraph<N, S>,
30
/// A cached topological ordering of the graph. This is recomputed when the
31
/// graph is modified, and is not valid when `dirty` is true.
32
toposort: Vec<N>,
33
/// Whether the graph has been modified since the last topological sort.
34
dirty: bool,
35
}
36
37
impl<N: GraphNodeId, S: BuildHasher> Dag<N, S> {
38
/// Creates a new directed acyclic graph.
39
pub fn new() -> Self
40
where
41
S: Default,
42
{
43
Self::default()
44
}
45
46
/// Read-only access to the underlying directed graph.
47
#[must_use]
48
pub fn graph(&self) -> &DiGraph<N, S> {
49
&self.graph
50
}
51
52
/// Mutable access to the underlying directed graph. Marks the graph as dirty.
53
#[must_use = "This function marks the graph as dirty, so it should be used."]
54
pub fn graph_mut(&mut self) -> &mut DiGraph<N, S> {
55
self.dirty = true;
56
&mut self.graph
57
}
58
59
/// Returns whether the graph is dirty (i.e., has been modified since the
60
/// last topological sort).
61
#[must_use]
62
pub fn is_dirty(&self) -> bool {
63
self.dirty
64
}
65
66
/// Returns whether the graph is topologically sorted (i.e., not dirty).
67
#[must_use]
68
pub fn is_toposorted(&self) -> bool {
69
!self.dirty
70
}
71
72
/// Ensures the graph is topologically sorted, recomputing the toposort if
73
/// the graph is dirty.
74
///
75
/// # Errors
76
///
77
/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
78
/// topologically sorted.
79
pub fn ensure_toposorted(&mut self) -> Result<(), DiGraphToposortError<N>> {
80
if self.dirty {
81
// recompute the toposort, reusing the existing allocation
82
self.toposort = self.graph.toposort(core::mem::take(&mut self.toposort))?;
83
self.dirty = false;
84
}
85
Ok(())
86
}
87
88
/// Returns the cached toposort if the graph is not dirty, otherwise returns
89
/// `None`.
90
#[must_use = "This method only returns a cached value and does not compute anything."]
91
pub fn get_toposort(&self) -> Option<&[N]> {
92
if self.dirty {
93
None
94
} else {
95
Some(&self.toposort)
96
}
97
}
98
99
/// Returns a topological ordering of the graph, computing it if the graph
100
/// is dirty.
101
///
102
/// # Errors
103
///
104
/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
105
/// topologically sorted.
106
pub fn toposort(&mut self) -> Result<&[N], DiGraphToposortError<N>> {
107
self.ensure_toposorted()?;
108
Ok(&self.toposort)
109
}
110
111
/// Returns both the topological ordering and the underlying graph,
112
/// computing the toposort if the graph is dirty.
113
///
114
/// This function is useful to avoid multiple borrow issues when both
115
/// the graph and the toposort are needed.
116
///
117
/// # Errors
118
///
119
/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
120
/// topologically sorted.
121
pub fn toposort_and_graph(
122
&mut self,
123
) -> Result<(&[N], &DiGraph<N, S>), DiGraphToposortError<N>> {
124
self.ensure_toposorted()?;
125
Ok((&self.toposort, &self.graph))
126
}
127
128
/// Processes a DAG and computes various properties about it.
129
///
130
/// See [`DagAnalysis::new`] for details on what is computed.
131
///
132
/// # Note
133
///
134
/// If the DAG is dirty, this method will first attempt to topologically sort it.
135
///
136
/// # Errors
137
///
138
/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
139
/// topologically sorted.
140
///
141
pub fn analyze(&mut self) -> Result<DagAnalysis<N, S>, DiGraphToposortError<N>>
142
where
143
S: Default,
144
{
145
let (toposort, graph) = self.toposort_and_graph()?;
146
Ok(DagAnalysis::new(graph, toposort))
147
}
148
149
/// Replaces the current graph with its transitive reduction based on the
150
/// provided analysis.
151
///
152
/// # Note
153
///
154
/// The given [`DagAnalysis`] must have been generated from this DAG.
155
pub fn remove_redundant_edges(&mut self, analysis: &DagAnalysis<N, S>)
156
where
157
S: Clone,
158
{
159
// We don't need to mark the graph as dirty, since transitive reduction
160
// is guaranteed to have the same topological ordering as the original graph.
161
self.graph = analysis.transitive_reduction.clone();
162
}
163
164
/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
165
/// under all of their ancestor key nodes. `num_groups` hints at the
166
/// expected number of groups, for memory allocation optimization.
167
///
168
/// The node type `N` must be convertible into either a key type `K` or
169
/// a value type `V` via the [`TryInto`] trait.
170
///
171
/// # Errors
172
///
173
/// Returns [`DiGraphToposortError`] if the DAG is dirty and cannot be
174
/// topologically sorted.
175
pub fn group_by_key<K, V>(
176
&mut self,
177
num_groups: usize,
178
) -> Result<DagGroups<K, V, S>, DiGraphToposortError<N>>
179
where
180
N: TryInto<K, Error = V>,
181
K: Eq + Hash,
182
V: Clone + Eq + Hash,
183
S: BuildHasher + Default,
184
{
185
let (toposort, graph) = self.toposort_and_graph()?;
186
Ok(DagGroups::with_capacity(num_groups, graph, toposort))
187
}
188
189
/// Converts from one [`GraphNodeId`] type to another. If the conversion fails,
190
/// it returns the error from the target type's [`TryFrom`] implementation.
191
///
192
/// Nodes must uniquely convert from `N` to `T` (i.e. no two `N` can convert
193
/// to the same `T`). The resulting DAG must be re-topologically sorted.
194
///
195
/// # Errors
196
///
197
/// If the conversion fails, it returns an error of type `N::Error`.
198
pub fn try_convert<T>(self) -> Result<Dag<T, S>, N::Error>
199
where
200
N: TryInto<T>,
201
T: GraphNodeId,
202
S: Default,
203
{
204
Ok(Dag {
205
graph: self.graph.try_convert()?,
206
toposort: Vec::new(),
207
dirty: true,
208
})
209
}
210
}
211
212
impl<N: GraphNodeId, S: BuildHasher> Deref for Dag<N, S> {
213
type Target = DiGraph<N, S>;
214
215
fn deref(&self) -> &Self::Target {
216
self.graph()
217
}
218
}
219
220
impl<N: GraphNodeId, S: BuildHasher> DerefMut for Dag<N, S> {
221
fn deref_mut(&mut self) -> &mut Self::Target {
222
self.graph_mut()
223
}
224
}
225
226
impl<N: GraphNodeId, S: BuildHasher + Default> Default for Dag<N, S> {
227
fn default() -> Self {
228
Self {
229
graph: Default::default(),
230
toposort: Default::default(),
231
dirty: false,
232
}
233
}
234
}
235
236
impl<N: GraphNodeId, S: BuildHasher> Debug for Dag<N, S> {
237
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238
if self.dirty {
239
f.debug_struct("Dag")
240
.field("graph", &self.graph)
241
.field("dirty", &self.dirty)
242
.finish()
243
} else {
244
f.debug_struct("Dag")
245
.field("graph", &self.graph)
246
.field("toposort", &self.toposort)
247
.finish()
248
}
249
}
250
}
251
252
/// Stores the results of a call to [`Dag::analyze`].
253
pub struct DagAnalysis<N: GraphNodeId, S: BuildHasher = FixedHasher> {
254
/// Boolean reachability matrix for the graph.
255
reachable: FixedBitSet,
256
/// Pairs of nodes that have a path connecting them.
257
connected: HashSet<(N, N), S>,
258
/// Pairs of nodes that don't have a path connecting them.
259
disconnected: Vec<(N, N)>,
260
/// Edges that are redundant because a longer path exists.
261
transitive_edges: Vec<(N, N)>,
262
/// Variant of the graph with no transitive edges.
263
transitive_reduction: DiGraph<N, S>,
264
/// Variant of the graph with all possible transitive edges.
265
transitive_closure: DiGraph<N, S>,
266
}
267
268
impl<N: GraphNodeId, S: BuildHasher> DagAnalysis<N, S> {
269
/// Processes a DAG and computes its:
270
/// - transitive reduction (along with the set of removed edges)
271
/// - transitive closure
272
/// - reachability matrix (as a bitset)
273
/// - pairs of nodes connected by a path
274
/// - pairs of nodes not connected by a path
275
///
276
/// The algorithm implemented comes from
277
/// ["On the calculation of transitive reduction-closure of orders"][1] by Habib, Morvan and Rampon.
278
///
279
/// [1]: https://doi.org/10.1016/0012-365X(93)90164-O
280
pub fn new(graph: &DiGraph<N, S>, topological_order: &[N]) -> Self
281
where
282
S: Default,
283
{
284
if graph.node_count() == 0 {
285
return DagAnalysis::default();
286
}
287
let n = graph.node_count();
288
289
// build a copy of the graph where the nodes and edges appear in topsorted order
290
let mut map = <HashMap<_, _>>::with_capacity_and_hasher(n, Default::default());
291
let mut topsorted =
292
DiGraph::<N>::with_capacity(topological_order.len(), graph.edge_count());
293
294
// iterate nodes in topological order
295
for (i, &node) in topological_order.iter().enumerate() {
296
map.insert(node, i);
297
topsorted.add_node(node);
298
// insert nodes as successors to their predecessors
299
for pred in graph.neighbors_directed(node, Incoming) {
300
topsorted.add_edge(pred, node);
301
}
302
}
303
304
let mut reachable = FixedBitSet::with_capacity(n * n);
305
let mut connected = HashSet::default();
306
let mut disconnected = Vec::default();
307
let mut transitive_edges = Vec::default();
308
let mut transitive_reduction = DiGraph::with_capacity(topsorted.node_count(), 0);
309
let mut transitive_closure = DiGraph::with_capacity(topsorted.node_count(), 0);
310
311
let mut visited = FixedBitSet::with_capacity(n);
312
313
// iterate nodes in topological order
314
for node in topsorted.nodes() {
315
transitive_reduction.add_node(node);
316
transitive_closure.add_node(node);
317
}
318
319
// iterate nodes in reverse topological order
320
for a in topsorted.nodes().rev() {
321
let index_a = *map.get(&a).unwrap();
322
// iterate their successors in topological order
323
for b in topsorted.neighbors_directed(a, Outgoing) {
324
let index_b = *map.get(&b).unwrap();
325
debug_assert!(index_a < index_b);
326
if !visited[index_b] {
327
// edge <a, b> is not redundant
328
transitive_reduction.add_edge(a, b);
329
transitive_closure.add_edge(a, b);
330
reachable.insert(index(index_a, index_b, n));
331
332
let successors = transitive_closure
333
.neighbors_directed(b, Outgoing)
334
.collect::<Vec<_>>();
335
for c in successors {
336
let index_c = *map.get(&c).unwrap();
337
debug_assert!(index_b < index_c);
338
if !visited[index_c] {
339
visited.insert(index_c);
340
transitive_closure.add_edge(a, c);
341
reachable.insert(index(index_a, index_c, n));
342
}
343
}
344
} else {
345
// edge <a, b> is redundant
346
transitive_edges.push((a, b));
347
}
348
}
349
350
visited.clear();
351
}
352
353
// partition pairs of nodes into "connected by path" and "not connected by path"
354
for i in 0..(n - 1) {
355
// reachable is upper triangular because the nodes were topsorted
356
for index in index(i, i + 1, n)..=index(i, n - 1, n) {
357
let (a, b) = row_col(index, n);
358
let pair = (topological_order[a], topological_order[b]);
359
if reachable[index] {
360
connected.insert(pair);
361
} else {
362
disconnected.push(pair);
363
}
364
}
365
}
366
367
// fill diagonal (nodes reach themselves)
368
// for i in 0..n {
369
// reachable.set(index(i, i, n), true);
370
// }
371
372
DagAnalysis {
373
reachable,
374
connected,
375
disconnected,
376
transitive_edges,
377
transitive_reduction,
378
transitive_closure,
379
}
380
}
381
382
/// Returns the reachability matrix.
383
pub fn reachable(&self) -> &FixedBitSet {
384
&self.reachable
385
}
386
387
/// Returns the set of node pairs that are connected by a path.
388
pub fn connected(&self) -> &HashSet<(N, N), S> {
389
&self.connected
390
}
391
392
/// Returns the list of node pairs that are not connected by a path.
393
pub fn disconnected(&self) -> &[(N, N)] {
394
&self.disconnected
395
}
396
397
/// Returns the list of redundant edges because a longer path exists.
398
pub fn transitive_edges(&self) -> &[(N, N)] {
399
&self.transitive_edges
400
}
401
402
/// Returns the transitive reduction of the graph.
403
pub fn transitive_reduction(&self) -> &DiGraph<N, S> {
404
&self.transitive_reduction
405
}
406
407
/// Returns the transitive closure of the graph.
408
pub fn transitive_closure(&self) -> &DiGraph<N, S> {
409
&self.transitive_closure
410
}
411
412
/// Checks if the graph has any redundant (transitive) edges.
413
///
414
/// # Errors
415
///
416
/// If there are redundant edges, returns a [`DagRedundancyError`]
417
/// containing the list of redundant edges.
418
pub fn check_for_redundant_edges(&self) -> Result<(), DagRedundancyError<N>>
419
where
420
S: Clone,
421
{
422
if self.transitive_edges.is_empty() {
423
Ok(())
424
} else {
425
Err(DagRedundancyError(self.transitive_edges.clone()))
426
}
427
}
428
429
/// Checks if there are any pairs of nodes that have a path in both this
430
/// graph and another graph.
431
///
432
/// # Errors
433
///
434
/// Returns [`DagCrossDependencyError`] if any node pair is connected in
435
/// both graphs.
436
pub fn check_for_cross_dependencies(
437
&self,
438
other: &Self,
439
) -> Result<(), DagCrossDependencyError<N>> {
440
for &(a, b) in &self.connected {
441
if other.connected.contains(&(a, b)) || other.connected.contains(&(b, a)) {
442
return Err(DagCrossDependencyError(a, b));
443
}
444
}
445
446
Ok(())
447
}
448
449
/// Checks if any connected node pairs that are both keys have overlapping
450
/// groups.
451
///
452
/// # Errors
453
///
454
/// If there are overlapping groups, returns a [`DagOverlappingGroupError`]
455
/// containing the first pair of keys that have overlapping groups.
456
pub fn check_for_overlapping_groups<K, V>(
457
&self,
458
groups: &DagGroups<K, V>,
459
) -> Result<(), DagOverlappingGroupError<K>>
460
where
461
N: TryInto<K>,
462
K: Eq + Hash,
463
V: Eq + Hash,
464
{
465
for &(a, b) in &self.connected {
466
let (Ok(a_key), Ok(b_key)) = (a.try_into(), b.try_into()) else {
467
continue;
468
};
469
let a_group = groups.get(&a_key).unwrap();
470
let b_group = groups.get(&b_key).unwrap();
471
if !a_group.is_disjoint(b_group) {
472
return Err(DagOverlappingGroupError(a_key, b_key));
473
}
474
}
475
Ok(())
476
}
477
}
478
479
impl<N: GraphNodeId, S: BuildHasher + Default> Default for DagAnalysis<N, S> {
480
fn default() -> Self {
481
Self {
482
reachable: Default::default(),
483
connected: Default::default(),
484
disconnected: Default::default(),
485
transitive_edges: Default::default(),
486
transitive_reduction: Default::default(),
487
transitive_closure: Default::default(),
488
}
489
}
490
}
491
492
impl<N: GraphNodeId, S: BuildHasher> Debug for DagAnalysis<N, S> {
493
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
494
f.debug_struct("DagAnalysis")
495
.field("reachable", &self.reachable)
496
.field("connected", &self.connected)
497
.field("disconnected", &self.disconnected)
498
.field("transitive_edges", &self.transitive_edges)
499
.field("transitive_reduction", &self.transitive_reduction)
500
.field("transitive_closure", &self.transitive_closure)
501
.finish()
502
}
503
}
504
505
/// A mapping of keys to groups of values in a [`Dag`].
506
pub struct DagGroups<K, V, S = FixedHasher>(HashMap<K, IndexSet<V, S>, S>);
507
508
impl<K: Eq + Hash, V: Clone + Eq + Hash, S: BuildHasher + Default> DagGroups<K, V, S> {
509
/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
510
/// under all of their ancestor key nodes.
511
///
512
/// The node type `N` must be convertible into either a key type `K` or
513
/// a value type `V` via the [`TryInto`] trait.
514
pub fn new<N>(graph: &DiGraph<N, S>, toposort: &[N]) -> Self
515
where
516
N: GraphNodeId + TryInto<K, Error = V>,
517
{
518
Self::with_capacity(0, graph, toposort)
519
}
520
521
/// Groups nodes in this DAG by a key type `K`, collecting value nodes `V`
522
/// under all of their ancestor key nodes. `capacity` hints at the
523
/// expected number of groups, for memory allocation optimization.
524
///
525
/// The node type `N` must be convertible into either a key type `K` or
526
/// a value type `V` via the [`TryInto`] trait.
527
pub fn with_capacity<N>(capacity: usize, graph: &DiGraph<N, S>, toposort: &[N]) -> Self
528
where
529
N: GraphNodeId + TryInto<K, Error = V>,
530
{
531
let mut groups: HashMap<K, IndexSet<V, S>, S> =
532
HashMap::with_capacity_and_hasher(capacity, Default::default());
533
534
// Iterate in reverse topological order (bottom-up) so we hit children before parents.
535
for &id in toposort.iter().rev() {
536
let Ok(key) = id.try_into() else {
537
continue;
538
};
539
540
let mut children = IndexSet::default();
541
542
for node in graph.neighbors_directed(id, Outgoing) {
543
match node.try_into() {
544
Ok(key) => {
545
// If the child is a key, this key inherits all of its children.
546
let key_children = groups.get(&key).unwrap();
547
children.extend(key_children.iter().cloned());
548
}
549
Err(value) => {
550
// If the child is a value, add it directly.
551
children.insert(value);
552
}
553
}
554
}
555
556
groups.insert(key, children);
557
}
558
559
Self(groups)
560
}
561
}
562
563
impl<K: GraphNodeId, V: GraphNodeId, S: BuildHasher> DagGroups<K, V, S> {
564
/// Converts the given [`Dag`] into a flattened version where key nodes
565
/// (`K`) are replaced by their associated value nodes (`V`). Edges to/from
566
/// key nodes are redirected to connect their value nodes instead.
567
///
568
/// The `collapse_group` function is called for each key node to customize
569
/// how its group is collapsed.
570
///
571
/// The resulting [`Dag`] will have only value nodes (`V`).
572
pub fn flatten<N>(
573
&self,
574
dag: Dag<N>,
575
mut collapse_group: impl FnMut(K, &IndexSet<V, S>, &Dag<N>, &mut Vec<(N, N)>),
576
) -> Dag<V>
577
where
578
N: GraphNodeId + TryInto<V, Error = K> + From<K> + From<V>,
579
{
580
let mut flattening = dag;
581
let mut temp = Vec::new();
582
583
for (&key, values) in self.iter() {
584
// Call the user-provided function to handle collapsing the group.
585
collapse_group(key, values, &flattening, &mut temp);
586
587
if values.is_empty() {
588
// Replace connections to the key node with connections between its neighbors.
589
for a in flattening.neighbors_directed(N::from(key), Incoming) {
590
for b in flattening.neighbors_directed(N::from(key), Outgoing) {
591
temp.push((a, b));
592
}
593
}
594
} else {
595
// Redirect edges to/from the key node to connect to its value nodes.
596
for a in flattening.neighbors_directed(N::from(key), Incoming) {
597
for &value in values {
598
temp.push((a, N::from(value)));
599
}
600
}
601
for b in flattening.neighbors_directed(N::from(key), Outgoing) {
602
for &value in values {
603
temp.push((N::from(value), b));
604
}
605
}
606
}
607
608
// Remove the key node from the graph.
609
flattening.remove_node(N::from(key));
610
// Add all previously collected edges.
611
flattening.reserve_edges(temp.len());
612
for (a, b) in temp.drain(..) {
613
flattening.add_edge(a, b);
614
}
615
}
616
617
// By this point, we should have removed all keys from the graph,
618
// so this conversion should never fail.
619
flattening
620
.try_convert::<V>()
621
.unwrap_or_else(|n| unreachable!("Flattened graph has a leftover key {n:?}"))
622
}
623
624
/// Converts an undirected graph by replacing key nodes (`K`) with their
625
/// associated value nodes (`V`). Edges connected to key nodes are
626
/// redirected to connect their value nodes instead.
627
///
628
/// The resulting undirected graph will have only value nodes (`V`).
629
pub fn flatten_undirected<N>(&self, graph: &UnGraph<N>) -> UnGraph<V>
630
where
631
N: GraphNodeId + TryInto<V, Error = K>,
632
{
633
let mut flattened = UnGraph::default();
634
635
for (lhs, rhs) in graph.all_edges() {
636
match (lhs.try_into(), rhs.try_into()) {
637
(Ok(lhs), Ok(rhs)) => {
638
// Normal edge between two value nodes
639
flattened.add_edge(lhs, rhs);
640
}
641
(Err(lhs_key), Ok(rhs)) => {
642
// Edge from a key node to a value node, expand to all values in the key's group
643
let Some(lhs_group) = self.get(&lhs_key) else {
644
continue;
645
};
646
flattened.reserve_edges(lhs_group.len());
647
for &lhs in lhs_group {
648
flattened.add_edge(lhs, rhs);
649
}
650
}
651
(Ok(lhs), Err(rhs_key)) => {
652
// Edge from a value node to a key node, expand to all values in the key's group
653
let Some(rhs_group) = self.get(&rhs_key) else {
654
continue;
655
};
656
flattened.reserve_edges(rhs_group.len());
657
for &rhs in rhs_group {
658
flattened.add_edge(lhs, rhs);
659
}
660
}
661
(Err(lhs_key), Err(rhs_key)) => {
662
// Edge between two key nodes, expand to all combinations of their value nodes
663
let Some(lhs_group) = self.get(&lhs_key) else {
664
continue;
665
};
666
let Some(rhs_group) = self.get(&rhs_key) else {
667
continue;
668
};
669
flattened.reserve_edges(lhs_group.len() * rhs_group.len());
670
for &lhs in lhs_group {
671
for &rhs in rhs_group {
672
flattened.add_edge(lhs, rhs);
673
}
674
}
675
}
676
}
677
}
678
679
flattened
680
}
681
}
682
683
impl<K, V, S> Deref for DagGroups<K, V, S> {
684
type Target = HashMap<K, IndexSet<V, S>, S>;
685
686
fn deref(&self) -> &Self::Target {
687
&self.0
688
}
689
}
690
691
impl<K, V, S> DerefMut for DagGroups<K, V, S> {
692
fn deref_mut(&mut self) -> &mut Self::Target {
693
&mut self.0
694
}
695
}
696
697
impl<K, V, S> Default for DagGroups<K, V, S>
698
where
699
S: BuildHasher + Default,
700
{
701
fn default() -> Self {
702
Self(Default::default())
703
}
704
}
705
706
impl<K: Debug, V: Debug, S> Debug for DagGroups<K, V, S> {
707
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
708
f.debug_tuple("DagGroups").field(&self.0).finish()
709
}
710
}
711
712
/// Error indicating that the graph has redundant edges.
713
#[derive(Error, Debug)]
714
#[error("DAG has redundant edges: {0:?}")]
715
pub struct DagRedundancyError<N: GraphNodeId>(pub Vec<(N, N)>);
716
717
/// Error indicating that two graphs both have a dependency between the same nodes.
718
#[derive(Error, Debug)]
719
#[error("DAG has a cross-dependency between nodes {0:?} and {1:?}")]
720
pub struct DagCrossDependencyError<N>(pub N, pub N);
721
722
/// Error indicating that the graph has overlapping groups between two keys.
723
#[derive(Error, Debug)]
724
#[error("DAG has overlapping groups between keys {0:?} and {1:?}")]
725
pub struct DagOverlappingGroupError<K>(pub K, pub K);
726
727
#[cfg(test)]
mod tests {
    use core::ops::DerefMut;

    use crate::schedule::graph::{index, Dag, Direction, GraphNodeId, UnGraph};

    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct TestNode(u32);

    impl GraphNodeId for TestNode {
        type Adjacent = (TestNode, Direction);
        type Edge = (TestNode, TestNode);

        fn kind(&self) -> &'static str {
            "test node"
        }
    }

    #[test]
    fn mark_dirty() {
        {
            let mut dag = Dag::<TestNode>::new();
            dag.add_node(TestNode(1));
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            dag.add_edge(TestNode(1), TestNode(2));
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            dag.deref_mut();
            assert!(dag.is_dirty());
        }
        {
            let mut dag = Dag::<TestNode>::new();
            let _ = dag.graph_mut();
            assert!(dag.is_dirty());
        }
    }

    #[test]
    fn toposort() {
        let mut dag = Dag::<TestNode>::new();
        dag.add_edge(TestNode(1), TestNode(2));
        dag.add_edge(TestNode(2), TestNode(3));
        dag.add_edge(TestNode(1), TestNode(3));

        assert_eq!(
            dag.toposort().unwrap(),
            &[TestNode(1), TestNode(2), TestNode(3)]
        );
        assert_eq!(
            dag.get_toposort().unwrap(),
            &[TestNode(1), TestNode(2), TestNode(3)]
        );
    }

    #[test]
    fn analyze() {
        let mut dag1 = Dag::<TestNode>::new();
        dag1.add_edge(TestNode(1), TestNode(2));
        dag1.add_edge(TestNode(2), TestNode(3));
        dag1.add_edge(TestNode(1), TestNode(3)); // redundant edge

        let analysis1 = dag1.analyze().unwrap();

        assert!(analysis1.reachable().contains(index(0, 1, 3)));
        assert!(analysis1.reachable().contains(index(1, 2, 3)));
        assert!(analysis1.reachable().contains(index(0, 2, 3)));

        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(2))));
        assert!(analysis1.connected().contains(&(TestNode(2), TestNode(3))));
        assert!(analysis1.connected().contains(&(TestNode(1), TestNode(3))));

        assert!(!analysis1
            .disconnected()
            .contains(&(TestNode(2), TestNode(1))));
        assert!(!analysis1
            .disconnected()
            .contains(&(TestNode(3), TestNode(2))));
        assert!(!analysis1
            .disconnected()
            .contains(&(TestNode(3), TestNode(1))));

        assert!(analysis1
            .transitive_edges()
            .contains(&(TestNode(1), TestNode(3))));

        assert!(analysis1.check_for_redundant_edges().is_err());

        let mut dag2 = Dag::<TestNode>::new();
        dag2.add_edge(TestNode(3), TestNode(4));

        let analysis2 = dag2.analyze().unwrap();

        assert!(analysis2.check_for_redundant_edges().is_ok());
        assert!(analysis1.check_for_cross_dependencies(&analysis2).is_ok());

        let mut dag3 = Dag::<TestNode>::new();
        dag3.add_edge(TestNode(1), TestNode(2));

        let analysis3 = dag3.analyze().unwrap();

        assert!(analysis1.check_for_cross_dependencies(&analysis3).is_err());

        dag1.remove_redundant_edges(&analysis1);
        let analysis1 = dag1.analyze().unwrap();
        assert!(analysis1.check_for_redundant_edges().is_ok());
    }

    /// An empty DAG starts clean and takes the early-return path of the analysis.
    #[test]
    fn analyze_empty() {
        let mut dag = Dag::<TestNode>::new();
        assert!(!dag.is_dirty());

        let analysis = dag.analyze().unwrap();
        assert!(analysis.connected().is_empty());
        assert!(analysis.disconnected().is_empty());
        assert!(analysis.transitive_edges().is_empty());
        assert!(analysis.check_for_redundant_edges().is_ok());
    }

    /// Replacing the graph with its transitive reduction keeps the cached
    /// toposort valid, so the DAG must not become dirty.
    #[test]
    fn remove_redundant_edges_keeps_toposort() {
        let mut dag = Dag::<TestNode>::new();
        dag.add_edge(TestNode(1), TestNode(2));
        dag.add_edge(TestNode(2), TestNode(3));
        dag.add_edge(TestNode(1), TestNode(3)); // redundant edge

        let analysis = dag.analyze().unwrap();
        dag.remove_redundant_edges(&analysis);

        assert!(dag.is_toposorted());
        assert_eq!(
            dag.get_toposort().unwrap(),
            &[TestNode(1), TestNode(2), TestNode(3)]
        );
        assert!(dag.contains_edge(TestNode(1), TestNode(2)));
        assert!(dag.contains_edge(TestNode(2), TestNode(3)));
        assert!(!dag.contains_edge(TestNode(1), TestNode(3)));
    }

    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    enum Node {
        Key(Key),
        Value(Value),
    }
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Key(u32);
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Value(u32);

    impl GraphNodeId for Node {
        type Adjacent = (Node, Direction);
        type Edge = (Node, Node);

        fn kind(&self) -> &'static str {
            "node"
        }
    }

    impl TryInto<Key> for Node {
        type Error = Value;

        fn try_into(self) -> Result<Key, Value> {
            match self {
                Node::Key(k) => Ok(k),
                Node::Value(v) => Err(v),
            }
        }
    }

    impl TryInto<Value> for Node {
        type Error = Key;

        fn try_into(self) -> Result<Value, Key> {
            match self {
                Node::Value(v) => Ok(v),
                Node::Key(k) => Err(k),
            }
        }
    }

    impl GraphNodeId for Key {
        type Adjacent = (Key, Direction);
        type Edge = (Key, Key);

        fn kind(&self) -> &'static str {
            "key"
        }
    }

    impl GraphNodeId for Value {
        type Adjacent = (Value, Direction);
        type Edge = (Value, Value);

        fn kind(&self) -> &'static str {
            "value"
        }
    }

    impl From<Key> for Node {
        fn from(key: Key) -> Self {
            Node::Key(key)
        }
    }

    impl From<Value> for Node {
        fn from(value: Value) -> Self {
            Node::Value(value)
        }
    }

    #[test]
    fn group_by_key() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Key(Key(1)));
        dag.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        assert_eq!(groups.len(), 2);

        let group_key1 = groups.get(&Key(1)).unwrap();
        assert!(group_key1.contains(&Value(10)));
        assert!(group_key1.contains(&Value(11)));

        let group_key2 = groups.get(&Key(2)).unwrap();
        assert!(group_key2.contains(&Value(10)));
        assert!(group_key2.contains(&Value(11)));
        assert!(group_key2.contains(&Value(20)));
    }

    #[test]
    fn flatten() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));
        dag.add_edge(Node::Value(Value(30)), Node::Key(Key(1)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(40)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let flattened = groups.flatten(dag, |_key, _values, _dag, _temp| {});

        assert!(flattened.contains_node(Value(10)));
        assert!(flattened.contains_node(Value(11)));
        assert!(flattened.contains_node(Value(20)));
        assert!(flattened.contains_node(Value(21)));
        assert!(flattened.contains_node(Value(30)));
        assert!(flattened.contains_node(Value(40)));

        assert!(flattened.contains_edge(Value(30), Value(10)));
        assert!(flattened.contains_edge(Value(30), Value(11)));
        assert!(flattened.contains_edge(Value(10), Value(40)));
        assert!(flattened.contains_edge(Value(11), Value(40)));
    }

    #[test]
    fn flatten_undirected() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();

        let mut ungraph = UnGraph::<Node>::default();
        ungraph.add_edge(Node::Value(Value(10)), Node::Value(Value(11)));
        ungraph.add_edge(Node::Key(Key(1)), Node::Value(Value(30)));
        ungraph.add_edge(Node::Value(Value(40)), Node::Key(Key(2)));
        ungraph.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));

        let flattened = groups.flatten_undirected(&ungraph);

        assert!(flattened.contains_edge(Value(10), Value(11)));
        assert!(flattened.contains_edge(Value(10), Value(30)));
        assert!(flattened.contains_edge(Value(11), Value(30)));
        assert!(flattened.contains_edge(Value(40), Value(20)));
        assert!(flattened.contains_edge(Value(40), Value(21)));
        assert!(flattened.contains_edge(Value(10), Value(20)));
        assert!(flattened.contains_edge(Value(10), Value(21)));
        assert!(flattened.contains_edge(Value(11), Value(20)));
        assert!(flattened.contains_edge(Value(11), Value(21)));
    }

    #[test]
    fn overlapping_groups() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(11))); // overlap
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(1)), Node::Key(Key(2)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let analysis = dag.analyze().unwrap();

        let result = analysis.check_for_overlapping_groups(&groups);
        assert!(result.is_err());
    }

    #[test]
    fn disjoint_groups() {
        let mut dag = Dag::<Node>::new();
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(10)));
        dag.add_edge(Node::Key(Key(1)), Node::Value(Value(11)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(20)));
        dag.add_edge(Node::Key(Key(2)), Node::Value(Value(21)));

        let groups = dag.group_by_key::<Key, Value>(2).unwrap();
        let analysis = dag.analyze().unwrap();

        let result = analysis.check_for_overlapping_groups(&groups);
        assert!(result.is_ok());
    }
}
1018
1019