use polars_error::PolarsResult;
use slotmap::{Key, SecondaryMap, SlotMap};
use crate::execute::StreamingExecutionState;
use crate::nodes::ComputeNode;
slotmap::new_key_type! {
pub struct GraphNodeKey;
pub struct LogicalPipeKey;
}
/// A graph of compute nodes connected by logical pipes.
///
/// Nodes and pipes live in two separate slot maps and refer to each other by
/// key ([`GraphNodeKey`] / [`LogicalPipeKey`]). The derived `Default` yields a
/// graph with no nodes and no pipes.
#[derive(Default)]
pub struct Graph {
    /// Every compute node in the graph, addressed by [`GraphNodeKey`].
    pub nodes: SlotMap<GraphNodeKey, GraphNode>,
    /// Every pipe between two nodes, addressed by [`LogicalPipeKey`].
    pub pipes: SlotMap<LogicalPipeKey, LogicalPipe>,
}
impl Graph {
    /// Creates an empty graph whose node map and pipe map can each hold
    /// `capacity` entries before reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            nodes: SlotMap::with_capacity_and_key(capacity),
            pipes: SlotMap::with_capacity_and_key(capacity),
        }
    }

    /// Inserts `node` into the graph and wires up its inputs.
    ///
    /// The `i`-th item of `inputs` is a `(sender, send_port)` pair. For each
    /// pair a new [`LogicalPipe`] is created, running from output port
    /// `send_port` of `sender` into receive port `i` of the new node. Both ends
    /// of every new pipe start out as [`PortState::Blocked`].
    ///
    /// Returns the key of the newly inserted node.
    ///
    /// # Panics
    ///
    /// Panics if a `sender` key is not present in `self.nodes`, or if output
    /// port `send_port` of that sender is already connected to a pipe.
    pub fn add_node<N: ComputeNode + 'static>(
        &mut self,
        node: N,
        inputs: impl IntoIterator<Item = (GraphNodeKey, usize)>,
    ) -> GraphNodeKey {
        let node_key = self.nodes.insert(GraphNode {
            compute: Box::new(node),
            inputs: Vec::new(),
            outputs: Vec::new(),
        });

        for (recv_port, (sender, send_port)) in inputs.into_iter().enumerate() {
            let pipe_key = self.pipes.insert(LogicalPipe {
                sender,
                send_port,
                send_state: PortState::Blocked,
                receiver: node_key,
                recv_port,
                recv_state: PortState::Blocked,
            });
            // Pushed in enumeration order, so index `recv_port` of `inputs`
            // is this pipe.
            self.nodes[node_key].inputs.push(pipe_key);

            // Grow the sender's output list, padding with null keys, so that
            // `send_port` is a valid index. A null entry means "not yet connected".
            let sender_outputs = &mut self.nodes[sender].outputs;
            if sender_outputs.len() <= send_port {
                sender_outputs.resize(send_port + 1, LogicalPipeKey::null());
            }
            assert!(
                sender_outputs[send_port].is_null(),
                "output port {send_port} of the sender node is already connected"
            );
            sender_outputs[send_port] = pipe_key;
        }

        node_key
    }

    /// Propagates port states between neighbouring nodes until no node
    /// remains scheduled for an update.
    ///
    /// Every node is scheduled initially. For a scheduled node, a receive
    /// buffer is filled with the `send_state` of each input pipe, and a send
    /// buffer with the `recv_state` of each output pipe. Both buffers are then
    /// handed to [`ComputeNode::update_state`], which may modify them in place.
    /// Every entry that changed is written back into the corresponding pipe.
    /// The node at the other end of that pipe is then scheduled, unless it is
    /// already pending.
    ///
    /// Setting the environment variable `POLARS_VERBOSE_STATE_UPDATE=1`
    /// prints each node's buffers before and after its update to stderr.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a node's `update_state`.
    ///
    /// # Panics
    ///
    /// Panics if a node attempts to change a pipe end that is already
    /// [`PortState::Done`]. Also panics if a node's output list still holds a
    /// null placeholder left by [`Graph::add_node`], i.e. an unconnected
    /// output port.
    pub fn update_all_states(&mut self, state: &StreamingExecutionState) -> PolarsResult<()> {
        // LIFO worklist, mirrored by a set so that no node is pending twice.
        let mut to_update: Vec<_> = self.nodes.keys().collect();
        let mut scheduled_for_update: SecondaryMap<GraphNodeKey, ()> =
            self.nodes.keys().map(|k| (k, ())).collect();

        let verbose = std::env::var("POLARS_VERBOSE_STATE_UPDATE").as_deref() == Ok("1");

        // Scratch buffers, reused across iterations to avoid reallocating.
        let mut recv_state = Vec::new();
        let mut send_state = Vec::new();
        while let Some(node_key) = to_update.pop() {
            scheduled_for_update.remove(node_key);
            let node = &mut self.nodes[node_key];

            // Seed the buffers with what the neighbours currently report:
            // upstream senders for our inputs, downstream receivers for our outputs.
            recv_state.clear();
            send_state.clear();
            recv_state.extend(node.inputs.iter().map(|i| self.pipes[*i].send_state));
            send_state.extend(node.outputs.iter().map(|o| self.pipes[*o].recv_state));

            if verbose {
                eprintln!(
                    "updating {}, before: {recv_state:?} {send_state:?}",
                    node.compute.name()
                );
            }
            node.compute
                .update_state(&mut recv_state, &mut send_state, state)?;
            if verbose {
                eprintln!(
                    "updating {}, after: {recv_state:?} {send_state:?}",
                    node.compute.name()
                );
            }

            // A changed receive state must be re-evaluated by the upstream sender.
            for (input, &new_state) in node.inputs.iter().zip(recv_state.iter()) {
                let pipe = &mut self.pipes[*input];
                if pipe.recv_state != new_state {
                    assert!(
                        pipe.recv_state != PortState::Done,
                        "implementation error: state transition from Done to Blocked/Ready attempted"
                    );
                    pipe.recv_state = new_state;
                    if scheduled_for_update.insert(pipe.sender, ()).is_none() {
                        to_update.push(pipe.sender);
                    }
                }
            }

            // A changed send state must be re-evaluated by the downstream receiver.
            for (output, &new_state) in node.outputs.iter().zip(send_state.iter()) {
                let pipe = &mut self.pipes[*output];
                if pipe.send_state != new_state {
                    assert!(
                        pipe.send_state != PortState::Done,
                        "implementation error: state transition from Done to Blocked/Ready attempted"
                    );
                    pipe.send_state = new_state;
                    if scheduled_for_update.insert(pipe.receiver, ()).is_none() {
                        to_update.push(pipe.receiver);
                    }
                }
            }
        }

        Ok(())
    }
}
/// A vertex of the [`Graph`]: a compute node together with the pipes
/// attached to its ports.
pub struct GraphNode {
    /// The compute node that runs at this vertex.
    pub compute: Box<dyn ComputeNode>,
    /// Incoming pipes; entry `i` is the pipe attached to receive port `i`.
    pub inputs: Vec<LogicalPipeKey>,
    /// Outgoing pipes; entry `i` is the pipe attached to send port `i`.
    /// Ports that were padded but never connected hold a null key.
    pub outputs: Vec<LogicalPipeKey>,
}
/// A directed edge of the [`Graph`]. It connects output port `send_port` of
/// node `sender` to input port `recv_port` of node `receiver`.
///
/// Each end carries its own [`PortState`]. `send_state` is updated from the
/// sender's state update, and `recv_state` from the receiver's state update.
// `Graph` itself never reads `send_port` or `recv_port`; they are only
// recorded at construction. The allow silences the resulting unused-field
// lint. NOTE(review): confirm whether other modules read these fields, and
// narrow this to `dead_code` if possible.
#[allow(unused)]
#[derive(Debug)]
pub struct LogicalPipe {
    /// Node whose output feeds this pipe.
    pub sender: GraphNodeKey,
    /// Output port index on `sender`.
    pub send_port: usize,
    /// State of the sending end.
    pub send_state: PortState,
    /// Node that consumes this pipe.
    pub receiver: GraphNodeKey,
    /// Input port index on `receiver`.
    pub recv_port: usize,
    /// State of the receiving end.
    pub recv_state: PortState,
}
/// State of one end of a [`LogicalPipe`].
///
/// The derived ordering follows declaration order: `Blocked < Ready < Done`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum PortState {
    /// Initial state of both ends of every freshly created pipe.
    Blocked,
    /// Neither blocked nor finished.
    Ready,
    /// Terminal state: once an end is `Done`, it may not transition back.
    Done,
}