// Path: blob/main/crates/polars-plan/src/plans/aexpr/traverse.rs
// (6940 views)
use super::*;12impl AExpr {3/// Push the inputs of this node to the given container, in reverse order.4/// This ensures the primary node responsible for the name is pushed last.5///6/// This is subtlely different from `children_rev` as this only includes the input expressions,7/// not expressions used during evaluation.8pub fn inputs_rev<E>(&self, container: &mut E)9where10E: Extend<Node>,11{12use AExpr::*;1314match self {15Column(_) | Literal(_) | Len => {},16BinaryExpr { left, op: _, right } => {17container.extend([*right, *left]);18},19Cast { expr, .. } => container.extend([*expr]),20Sort { expr, .. } => container.extend([*expr]),21Gather { expr, idx, .. } => {22container.extend([*idx, *expr]);23},24SortBy { expr, by, .. } => {25container.extend(by.iter().cloned().rev());26container.extend([*expr]);27},28Filter { input, by } => {29container.extend([*by, *input]);30},31Agg(agg_e) => match agg_e.get_input() {32NodeInputs::Single(node) => container.extend([node]),33NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),34NodeInputs::Leaf => {},35},36Ternary {37truthy,38falsy,39predicate,40} => {41container.extend([*predicate, *falsy, *truthy]);42},43AnonymousFunction { input, .. } | Function { input, .. } => {44container.extend(input.iter().rev().map(|e| e.node()))45},46Explode { expr: e, .. 
} => container.extend([*e]),47Window {48function,49partition_by,50order_by,51options: _,52} => {53if let Some((n, _)) = order_by {54container.extend([*n]);55}56container.extend(partition_by.iter().rev().cloned());57container.extend([*function]);58},59Eval {60expr,61evaluation,62variant: _,63} => {64// We don't use the evaluation here because it does not contain inputs.65_ = evaluation;66container.extend([*expr]);67},68Slice {69input,70offset,71length,72} => {73container.extend([*length, *offset, *input]);74},75}76}7778/// Push the children of this node to the given container, in reverse order.79/// This ensures the primary node responsible for the name is pushed last.80///81/// This is subtlely different from `input_rev` as this only all expressions included in the82/// expression not only the input expressions,83pub fn children_rev<E: Extend<Node>>(&self, container: &mut E) {84use AExpr::*;8586match self {87Column(_) | Literal(_) | Len => {},88BinaryExpr { left, op: _, right } => {89container.extend([*right, *left]);90},91Cast { expr, .. } => container.extend([*expr]),92Sort { expr, .. } => container.extend([*expr]),93Gather { expr, idx, .. } => {94container.extend([*idx, *expr]);95},96SortBy { expr, by, .. } => {97container.extend(by.iter().cloned().rev());98container.extend([*expr]);99},100Filter { input, by } => {101container.extend([*by, *input]);102},103Agg(agg_e) => match agg_e.get_input() {104NodeInputs::Single(node) => container.extend([node]),105NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),106NodeInputs::Leaf => {},107},108Ternary {109truthy,110falsy,111predicate,112} => {113container.extend([*predicate, *falsy, *truthy]);114},115AnonymousFunction { input, .. } | Function { input, .. } => {116container.extend(input.iter().rev().map(|e| e.node()))117},118Explode { expr: e, .. 
} => container.extend([*e]),119Window {120function,121partition_by,122order_by,123options: _,124} => {125if let Some((n, _)) = order_by {126container.extend([*n]);127}128container.extend(partition_by.iter().rev().cloned());129container.extend([*function]);130},131Eval {132expr,133evaluation,134variant: _,135} => container.extend([*evaluation, *expr]),136Slice {137input,138offset,139length,140} => {141container.extend([*length, *offset, *input]);142},143}144}145146pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {147use AExpr::*;148let input = match &mut self {149Column(_) | Literal(_) | Len => return self,150Cast { expr, .. } => expr,151Explode { expr, .. } => expr,152BinaryExpr { left, right, .. } => {153*left = inputs[0];154*right = inputs[1];155return self;156},157Gather { expr, idx, .. } => {158*expr = inputs[0];159*idx = inputs[1];160return self;161},162Sort { expr, .. } => expr,163SortBy { expr, by, .. } => {164*expr = inputs[0];165by.clear();166by.extend_from_slice(&inputs[1..]);167return self;168},169Filter { input, by, .. } => {170*input = inputs[0];171*by = inputs[1];172return self;173},174Agg(a) => {175match a {176IRAggExpr::Quantile { expr, quantile, .. } => {177*expr = inputs[0];178*quantile = inputs[1];179},180_ => {181a.set_input(inputs[0]);182},183}184return self;185},186Ternary {187truthy,188falsy,189predicate,190} => {191*truthy = inputs[0];192*falsy = inputs[1];193*predicate = inputs[2];194return self;195},196AnonymousFunction { input, .. } | Function { input, .. 
} => {197assert_eq!(input.len(), inputs.len());198for (e, node) in input.iter_mut().zip(inputs.iter()) {199e.set_node(*node);200}201return self;202},203Eval {204expr,205evaluation,206variant: _,207} => {208*expr = inputs[0];209_ = evaluation; // Intentional.210return self;211},212Slice {213input,214offset,215length,216} => {217*input = inputs[0];218*offset = inputs[1];219*length = inputs[2];220return self;221},222Window {223function,224partition_by,225order_by,226..227} => {228let offset = order_by.is_some() as usize;229*function = inputs[0];230partition_by.clear();231partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);232if let Some((_, options)) = order_by {233*order_by = Some((*inputs.last().unwrap(), *options));234}235return self;236},237};238*input = inputs[0];239self240}241}242243impl IRAggExpr {244pub fn get_input(&self) -> NodeInputs {245use IRAggExpr::*;246use NodeInputs::*;247match self {248Min { input, .. } => Single(*input),249Max { input, .. } => Single(*input),250Median(input) => Single(*input),251NUnique(input) => Single(*input),252First(input) => Single(*input),253Last(input) => Single(*input),254Mean(input) => Single(*input),255Implode(input) => Single(*input),256Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),257Sum(input) => Single(*input),258Count { input, .. } => Single(*input),259Std(input, _) => Single(*input),260Var(input, _) => Single(*input),261AggGroups(input) => Single(*input),262}263}264pub fn set_input(&mut self, input: Node) {265use IRAggExpr::*;266let node = match self {267Min { input, .. } => input,268Max { input, .. } => input,269Median(input) => input,270NUnique(input) => input,271First(input) => input,272Last(input) => input,273Mean(input) => input,274Implode(input) => input,275Quantile { expr, .. } => expr,276Sum(input) => input,277Count { input, .. 
} => input,278Std(input, _) => input,279Var(input, _) => input,280AggGroups(input) => input,281};282*node = input;283}284}285286pub enum NodeInputs {287Leaf,288Single(Node),289Many(Vec<Node>),290}291292impl NodeInputs {293pub fn first(&self) -> Node {294match self {295NodeInputs::Single(node) => *node,296NodeInputs::Many(nodes) => nodes[0],297NodeInputs::Leaf => panic!(),298}299}300}301302303