use std::time::Instant;
use parking_lot::Mutex;
use polars_error::PolarsResult;
use slotmap::{Key, SecondaryMap, SlotMap};
use crate::execute::StreamingExecutionState;
use crate::metrics::GraphMetrics;
use crate::nodes::ComputeNode;
slotmap::new_key_type! {
pub struct GraphNodeKey;
pub struct LogicalPipeKey;
}
/// A streaming execution graph of compute nodes joined by logical pipes.
///
/// Nodes and pipes are stored in separate slot maps and refer to each other by
/// key: every [`GraphNode`] lists the keys of the pipes attached to its input and
/// output ports, and every [`LogicalPipe`] records the keys of its sender and
/// receiver nodes.
#[derive(Default)]
pub struct Graph {
/// All compute nodes, keyed by [`GraphNodeKey`].
pub nodes: SlotMap<GraphNodeKey, GraphNode>,
/// All pipes, each linking an output port of one node to an input port of a node.
pub pipes: SlotMap<LogicalPipeKey, LogicalPipe>,
}
impl Graph {
    /// Creates an empty graph whose node map and pipe map can each hold
    /// `capacity` entries before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            nodes: SlotMap::with_capacity_and_key(capacity),
            pipes: SlotMap::with_capacity_and_key(capacity),
        }
    }

    /// Adds `node` to the graph and connects its inputs, returning the new node's key.
    ///
    /// `inputs` yields `(sender, send_port)` pairs; the `i`-th pair is connected
    /// to receive port `i` of the new node. Each connection creates a new
    /// [`LogicalPipe`] whose send and receive states both start as
    /// [`PortState::Blocked`], and registers that pipe on output port
    /// `send_port` of `sender`.
    ///
    /// # Panics
    ///
    /// Panics if a `sender` key does not refer to a node in this graph, or if
    /// output port `send_port` of a sender is already connected.
    pub fn add_node<N: ComputeNode + 'static>(
        &mut self,
        node: N,
        inputs: impl IntoIterator<Item = (GraphNodeKey, usize)>,
    ) -> GraphNodeKey {
        let node_key = self.nodes.insert(GraphNode {
            compute: Box::new(node),
            inputs: Vec::new(),
            outputs: Vec::new(),
        });

        for (recv_port, (sender, send_port)) in inputs.into_iter().enumerate() {
            let pipe_key = self.pipes.insert(LogicalPipe {
                sender,
                send_port,
                send_state: PortState::Blocked,
                receiver: node_key,
                recv_port,
                recv_state: PortState::Blocked,
            });
            // Inputs are pushed in iteration order, so index == recv_port.
            self.nodes[node_key].inputs.push(pipe_key);

            // Output ports may be connected out of order: grow the sender's output
            // list with null placeholders so `send_port` is addressable.
            // NOTE(review): `update_all_states` looks up every entry of `outputs`
            // in `pipes`, so a placeholder that is never filled would presumably
            // panic there — confirm callers connect every port.
            let sender_outputs = &mut self.nodes[sender].outputs;
            if sender_outputs.len() <= send_port {
                sender_outputs.resize(send_port + 1, LogicalPipeKey::null());
            }
            assert!(
                sender_outputs[send_port].is_null(),
                "output port {send_port} of sender node is already connected"
            );
            sender_outputs[send_port] = pipe_key;
        }

        node_key
    }

    /// Repeatedly updates node port states until no node has pending changes.
    ///
    /// Every node starts out scheduled. Updating a node works as follows:
    /// - Its receive-port states are seeded from the `send_state` of its input
    ///   pipes, and its send-port states from the `recv_state` of its output pipes.
    /// - It calls [`ComputeNode::update_state`].
    /// - It writes the resulting states back onto the pipes.
    ///
    /// Whenever a pipe's state changes, the node on the other end is rescheduled
    /// unless it is already pending.
    ///
    /// If `metrics` is given, each update is timed and reported to it. If the
    /// polars `verbose` config flag is set, each node's states before and after
    /// the update, and the time taken, are printed to stderr.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a node's `update_state`. Nodes still
    /// scheduled at that point are not updated.
    ///
    /// # Panics
    ///
    /// Panics if a node tries to move a port out of [`PortState::Done`].
    pub fn update_all_states(
        &mut self,
        state: &StreamingExecutionState,
        metrics: Option<&Mutex<GraphMetrics>>,
    ) -> PolarsResult<()> {
        // LIFO worklist of nodes to update. `scheduled_for_update` mirrors its
        // contents so a node is never queued twice at the same time.
        let mut to_update: Vec<_> = self.nodes.keys().collect();
        let mut scheduled_for_update: SecondaryMap<GraphNodeKey, ()> =
            self.nodes.keys().map(|k| (k, ())).collect();

        let verbose = polars_config::config().verbose();

        // Scratch buffers, reused across iterations.
        let mut recv_state = Vec::new();
        let mut send_state = Vec::new();
        while let Some(node_key) = to_update.pop() {
            scheduled_for_update.remove(node_key);
            let node = &mut self.nodes[node_key];

            // Seed the node's view of its ports with what its neighbours last
            // published on the connecting pipes.
            recv_state.clear();
            send_state.clear();
            recv_state.extend(node.inputs.iter().map(|i| self.pipes[*i].send_state));
            send_state.extend(node.outputs.iter().map(|o| self.pipes[*o].recv_state));

            if verbose {
                eprintln!(
                    "updating {}, before: {recv_state:?} {send_state:?}",
                    node.compute.name()
                );
            }

            let start = (metrics.is_some() || verbose).then(Instant::now);
            if let Some(lock) = metrics {
                lock.lock().start_state_update(node_key);
            }
            node.compute
                .update_state(&mut recv_state, &mut send_state, state)?;
            let elapsed = start.map(|s| s.elapsed());

            if let Some(lock) = metrics {
                let is_done = recv_state.iter().all(|s| *s == PortState::Done)
                    && send_state.iter().all(|s| *s == PortState::Done);
                lock.lock().stop_state_update(
                    node_key,
                    elapsed.expect("timer is started whenever metrics are enabled"),
                    is_done,
                );
            }
            if verbose {
                eprintln!(
                    "updating {}, after: {recv_state:?} {send_state:?} (took {:?})",
                    node.compute.name(),
                    elapsed.expect("timer is started whenever verbose is enabled")
                );
            }

            // Publish changed receive states upstream and wake the senders.
            for (input, new_state) in node.inputs.iter().zip(recv_state.iter()) {
                let pipe = &mut self.pipes[*input];
                if pipe.recv_state != *new_state {
                    assert!(
                        pipe.recv_state != PortState::Done,
                        "implementation error: state transition from Done to Blocked/Ready attempted"
                    );
                    pipe.recv_state = *new_state;
                    if scheduled_for_update.insert(pipe.sender, ()).is_none() {
                        to_update.push(pipe.sender);
                    }
                }
            }

            // Publish changed send states downstream and wake the receivers.
            for (output, new_state) in node.outputs.iter().zip(send_state.iter()) {
                let pipe = &mut self.pipes[*output];
                if pipe.send_state != *new_state {
                    assert!(
                        pipe.send_state != PortState::Done,
                        "implementation error: state transition from Done to Blocked/Ready attempted"
                    );
                    pipe.send_state = *new_state;
                    if scheduled_for_update.insert(pipe.receiver, ()).is_none() {
                        to_update.push(pipe.receiver);
                    }
                }
            }
        }
        Ok(())
    }
}
/// A node in a [`Graph`]: a compute node plus the pipes attached to its ports.
pub struct GraphNode {
/// The compute node whose `update_state` is driven by `Graph::update_all_states`.
pub compute: Box<dyn ComputeNode>,
/// Input pipes; index `i` is the pipe on receive port `i`.
pub inputs: Vec<LogicalPipeKey>,
/// Output pipes; index `i` is the pipe on send port `i`. An entry is a null key
/// (`LogicalPipeKey::null()`) until that port gets connected by `Graph::add_node`.
pub outputs: Vec<LogicalPipeKey>,
}
/// A directed connection from an output port of one node to an input port of another.
///
/// Each endpoint publishes its port state on the pipe. The sender writes
/// `send_state` and the receiver writes `recv_state`, so each side can read
/// what the other last reported.
// NOTE(review): `allow(unused)` presumably silences dead-code warnings for fields
// that are never read elsewhere (e.g. the port indices) — confirm, then document
// or narrow it.
#[allow(unused)]
#[derive(Clone)]
pub struct LogicalPipe {
/// Node whose output port feeds this pipe.
pub sender: GraphNodeKey,
/// Index of the output port on `sender`.
pub send_port: usize,
/// Port state last published by the sender (starts as `PortState::Blocked`).
pub send_state: PortState,
/// Node whose input port this pipe feeds.
pub receiver: GraphNodeKey,
/// Index of the input port on `receiver`.
pub recv_port: usize,
/// Port state last published by the receiver (starts as `PortState::Blocked`).
pub recv_state: PortState,
}
/// State of one end of a [`LogicalPipe`], as published by the node owning that port.
///
/// The derived ordering follows declaration order: `Blocked < Ready < Done`.
/// Keep the variants in this order; reordering them changes comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PortState {
    /// Initial state of both ends of a newly created pipe.
    Blocked,
    /// Sits between `Blocked` and `Done` in the derived ordering.
    Ready,
    /// Terminal: `Graph::update_all_states` panics if a port leaves this state.
    Done,
}