// Path: crates/polars-plan/src/plans/visitor/expr.rs
use std::fmt::{Debug, Formatter};12use polars_core::prelude::{Field, Schema};3use polars_utils::unitvec;45use super::*;6use crate::prelude::*;78impl TreeWalker for Expr {9type Arena = ();1011fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(12&self,13op: &mut F,14arena: &Self::Arena,15) -> PolarsResult<VisitRecursion> {16let mut scratch = unitvec![];1718self.nodes(&mut scratch);1920for &child in scratch.as_slice() {21match op(child, arena)? {22// let the recursion continue23VisitRecursion::Continue | VisitRecursion::Skip => {},24// early stop25VisitRecursion::Stop => return Ok(VisitRecursion::Stop),26}27}28Ok(VisitRecursion::Continue)29}3031fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(32self,33f: &mut F,34_arena: &mut Self::Arena,35) -> PolarsResult<Self> {36use polars_utils::functions::try_arc_map as am;37let mut f = |expr| f(expr, &mut ());38use AggExpr::*;39use Expr::*;40#[rustfmt::skip]41let ret = match self {42Alias(l, r) => Alias(am(l, f)?, r),43Column(_) => self,44Literal(_) => self,45DataTypeFunction(_) => self,46#[cfg(feature = "dtype-struct")]47Field(_) => self,48BinaryExpr { left, op, right } => {49BinaryExpr { left: am(left, &mut f)? 
, op, right: am(right, f)?}50},51Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },52Sort { expr, options } => Sort { expr: am(expr, f)?, options },53Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar },54SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },55Agg(agg_expr) => Agg(match agg_expr {56Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },57Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },58Median(x) => Median(am(x, f)?),59NUnique(x) => NUnique(am(x, f)?),60First(x) => First(am(x, f)?),61Last(x) => Last(am(x, f)?),62Mean(x) => Mean(am(x, f)?),63Implode(x) => Implode(am(x, f)?),64Count { input, include_nulls } => Count { input: am(input, f)?, include_nulls },65Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },66Sum(x) => Sum(am(x, f)?),67AggGroups(x) => AggGroups(am(x, f)?),68Std(x, ddf) => Std(am(x, f)?, ddf),69Var(x, ddf) => Var(am(x, f)?, ddf),70}),71Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },72Function { input, function } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function },73Explode { input, skip_empty } => Explode { input: am(input, f)?, skip_empty },74Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },75Window { function, partition_by, order_by, options } => {76let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;77Window { function: am(function, f)?, partition_by, order_by, options }78},79Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? 
},80KeepName(expr) => KeepName(am(expr, f)?),81Len => Len,82RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },83AnonymousFunction { input, function, options, fmt_str } => {84AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options, fmt_str }85},86Eval { expr: input, evaluation, variant } => Eval { expr: am(input, &mut f)?, evaluation: am(evaluation, f)?, variant },87SubPlan(_, _) => self,88Selector(_) => self,89};90Ok(ret)91}92}9394#[derive(Copy, Clone, Debug)]95pub struct AexprNode {96node: Node,97}9899impl AexprNode {100pub fn new(node: Node) -> Self {101Self { node }102}103104/// Get the `Node`.105pub fn node(&self) -> Node {106self.node107}108109pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {110arena.get(self.node)111}112113pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {114node_to_expr(self.node, arena)115}116117pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {118let aexpr = arena.get(self.node);119aexpr.to_field(schema, arena)120}121122pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {123let node = arena.add(ae);124self.node = node;125}126127pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {128matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))129}130131pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {132AExprArena {133node: self.node,134arena,135}136}137}138139pub struct AExprArena<'a> {140node: Node,141arena: &'a Arena<AExpr>,142}143144impl Debug for AExprArena<'_> {145fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {146write!(f, "AexprArena: {}", self.node.0)147}148}149150impl AExpr {151fn is_equal_node(&self, other: &Self) -> bool {152use AExpr::*;153match (self, other) {154(Column(l), Column(r)) => l == r,155(Literal(l), Literal(r)) => l == r,156(Window { options: l, .. }, Window { options: r, .. 
}) => l == r,157(158Cast {159options: strict_l,160dtype: dtl,161..162},163Cast {164options: strict_r,165dtype: dtr,166..167},168) => strict_l == strict_r && dtl == dtr,169(Sort { options: l, .. }, Sort { options: r, .. }) => l == r,170(Gather { .. }, Gather { .. })171| (Filter { .. }, Filter { .. })172| (Ternary { .. }, Ternary { .. })173| (Len, Len)174| (Slice { .. }, Slice { .. }) => true,175(176Explode {177expr: _,178skip_empty: l_skip_empty,179},180Explode {181expr: _,182skip_empty: r_skip_empty,183},184) => l_skip_empty == r_skip_empty,185(186SortBy {187sort_options: l_sort_options,188..189},190SortBy {191sort_options: r_sort_options,192..193},194) => l_sort_options == r_sort_options,195(Agg(l), Agg(r)) => l.equal_nodes(r),196(197Function {198input: il,199function: fl,200options: ol,201},202Function {203input: ir,204function: fr,205options: or,206},207) => {208fl == fr && ol == or && {209let mut all_same_name = true;210for (l, r) in il.iter().zip(ir) {211all_same_name &= l.output_name() == r.output_name()212}213214all_same_name215}216},217(AnonymousFunction { .. }, AnonymousFunction { .. }) => false,218(BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. 
}) => l == r,219_ => false,220}221}222}223224impl<'a> AExprArena<'a> {225pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {226Self { node, arena }227}228pub fn to_aexpr(&self) -> &'a AExpr {229self.arena.get(self.node)230}231232// Check single node on equality233pub fn is_equal_single(&self, other: &Self) -> bool {234let self_ae = self.to_aexpr();235let other_ae = other.to_aexpr();236self_ae.is_equal_node(other_ae)237}238}239240impl PartialEq for AExprArena<'_> {241fn eq(&self, other: &Self) -> bool {242let mut scratch1 = unitvec![];243let mut scratch2 = unitvec![];244245scratch1.push(self.node);246scratch2.push(other.node);247248loop {249match (scratch1.pop(), scratch2.pop()) {250(Some(l), Some(r)) => {251let l = Self::new(l, self.arena);252let r = Self::new(r, self.arena);253254if !l.is_equal_single(&r) {255return false;256}257258l.to_aexpr().inputs_rev(&mut scratch1);259r.to_aexpr().inputs_rev(&mut scratch2);260},261(None, None) => return true,262_ => return false,263}264}265}266}267268impl TreeWalker for AexprNode {269type Arena = Arena<AExpr>;270fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(271&self,272op: &mut F,273arena: &Self::Arena,274) -> PolarsResult<VisitRecursion> {275let mut scratch = unitvec![];276277self.to_aexpr(arena).inputs_rev(&mut scratch);278for node in scratch.as_slice() {279let aenode = AexprNode::new(*node);280match op(&aenode, arena)? 
{281// let the recursion continue282VisitRecursion::Continue | VisitRecursion::Skip => {},283// early stop284VisitRecursion::Stop => return Ok(VisitRecursion::Stop),285}286}287Ok(VisitRecursion::Continue)288}289290fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(291mut self,292op: &mut F,293arena: &mut Self::Arena,294) -> PolarsResult<Self> {295let mut scratch = unitvec![];296297let ae = arena.get(self.node).clone();298ae.inputs_rev(&mut scratch);299300// rewrite the nodes301for node in scratch.as_mut_slice() {302let aenode = AexprNode::new(*node);303*node = op(aenode, arena)?.node;304}305306scratch.as_mut_slice().reverse();307let ae = ae.replace_inputs(&scratch);308self.node = arena.add(ae);309Ok(self)310}311}312313314