Path: blob/main/crates/polars-plan/src/plans/visitor/expr.rs
8446 views
use std::fmt::{Debug, Formatter};12use polars_core::prelude::{Field, Schema};3use polars_utils::unitvec;45use super::*;6use crate::prelude::*;78impl TreeWalker for Expr {9type Arena = ();1011fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(12&self,13op: &mut F,14arena: &Self::Arena,15) -> PolarsResult<VisitRecursion> {16let mut scratch = unitvec![];1718self.nodes(&mut scratch);1920for &child in scratch.as_slice() {21match op(child, arena)? {22// let the recursion continue23VisitRecursion::Continue | VisitRecursion::Skip => {},24// early stop25VisitRecursion::Stop => return Ok(VisitRecursion::Stop),26}27}28Ok(VisitRecursion::Continue)29}3031fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(32self,33f: &mut F,34_arena: &mut Self::Arena,35) -> PolarsResult<Self> {36use polars_utils::functions::try_arc_map as am;37let mut f = |expr| f(expr, &mut ());38use AggExpr::*;39use Expr::*;40#[rustfmt::skip]41let ret = match self {42Alias(l, r) => Alias(am(l, f)?, r),43Column(_) => self,44Literal(_) => self,45DataTypeFunction(_) => self,46#[cfg(feature = "dtype-struct")]47Field(_) => self,48BinaryExpr { left, op, right } => {49BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}50},51Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict },52Sort { expr, options } => Sort { expr: am(expr, f)?, options },53Gather { expr, idx, returns_scalar, null_on_oob } => Gather {54expr: am(expr, &mut f)?,55idx: am(idx, f)?,56returns_scalar,57null_on_oob,58},59SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::<Result<_, _>>()?, sort_options },60Agg(agg_expr) => Agg(match agg_expr {61Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans },62Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans },63Median(x) => Median(am(x, f)?),64NUnique(x) => NUnique(am(x, f)?),65First(x) => First(am(x, f)?),66FirstNonNull(x) => FirstNonNull(am(x, f)?),67Last(x) => Last(am(x, f)?),68LastNonNull(x) => LastNonNull(am(x, f)?),69Item { input, allow_empty } => Item { input: am(input, f)?, allow_empty },70Mean(x) => Mean(am(x, f)?),71Implode(x) => Implode(am(x, f)?),72Count { input, include_nulls } => Count { input: am(input, f)?, include_nulls },73Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol },74Sum(x) => Sum(am(x, f)?),75AggGroups(x) => AggGroups(am(x, f)?),76Std(x, ddf) => Std(am(x, f)?, ddf),77Var(x, ddf) => Var(am(x, f)?, ddf),7879}),80Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? },81Function { input, function } => Function { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function },82Explode { input, options } => Explode { input: am(input, f)?, options },83Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? },84#[cfg(feature = "dynamic_group_by")]85Rolling { function, index_column, period, offset, closed_window } => Rolling { function: am(function, &mut f)?, index_column: am(index_column, &mut f)?, period, offset, closed_window },86Over { function, partition_by, order_by, mapping } => {87let partition_by = partition_by.into_iter().map(&mut f).collect::<Result<_, _>>()?;88Over { function: am(function, f)?, partition_by, order_by, mapping }89},90Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? },91KeepName(expr) => KeepName(am(expr, f)?),92Element => Element,93Len => Len,94RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? },95Display { inputs, fmt_str } => {96Display { inputs: inputs.into_iter().map(f).collect::<Result<_, _>>()?, fmt_str }97},98AnonymousFunction { input, function, options, fmt_str } => {99AnonymousFunction { input: input.into_iter().map(f).collect::<Result<_, _>>()?, function, options, fmt_str }100},101Eval { expr: input, evaluation, variant } => Eval { expr: am(input, &mut f)?, evaluation: am(evaluation, f)?, variant },102#[cfg(feature = "dtype-struct")]103StructEval { expr: input, evaluation } => {104StructEval { expr: am(input, &mut f)?, evaluation: evaluation.into_iter().map(f).collect::<Result<_, _>>()? }105},106SubPlan(_, _) => self,107Selector(_) => self,108};109Ok(ret)110}111}112113#[derive(Copy, Clone, Debug)]114pub struct AexprNode {115node: Node,116}117118impl AexprNode {119pub fn new(node: Node) -> Self {120Self { node }121}122123/// Get the `Node`.124pub fn node(&self) -> Node {125self.node126}127128pub fn to_aexpr<'a>(&self, arena: &'a Arena<AExpr>) -> &'a AExpr {129arena.get(self.node)130}131132pub fn to_expr(&self, arena: &Arena<AExpr>) -> Expr {133node_to_expr(self.node, arena)134}135136pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {137let aexpr = arena.get(self.node);138aexpr.to_field(&ToFieldContext::new(arena, schema))139}140141pub fn assign(&mut self, ae: AExpr, arena: &mut Arena<AExpr>) {142let node = arena.add(ae);143self.node = node;144}145146pub(crate) fn is_leaf(&self, arena: &Arena<AExpr>) -> bool {147matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_))148}149150pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena<AExpr>) -> AExprArena<'a> {151AExprArena {152node: self.node,153arena,154}155}156}157158pub struct AExprArena<'a> {159node: Node,160arena: &'a Arena<AExpr>,161}162163impl Debug for AExprArena<'_> {164fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {165write!(f, "AexprArena: {}", self.node.0)166}167}168169impl AExpr {170fn is_equal_node(&self, other: &Self) -> bool {171use AExpr::*;172match (self, other) {173(Column(l), Column(r)) => l == r,174(Literal(l), Literal(r)) => l == r,175#[cfg(feature = "dynamic_group_by")]176(177Rolling {178function: _,179index_column: _,180period: l_period,181offset: l_offset,182closed_window: l_closed_window,183},184Rolling {185function: _,186index_column: _,187period: r_period,188offset: r_offset,189closed_window: r_closed_window,190},191) => l_period == r_period && l_offset == r_offset && l_closed_window == r_closed_window,192(Over { mapping: l, .. }, Over { mapping: r, .. }) => l == r,193(194Cast {195options: strict_l,196dtype: dtl,197..198},199Cast {200options: strict_r,201dtype: dtr,202..203},204) => strict_l == strict_r && dtl == dtr,205(Sort { options: l, .. }, Sort { options: r, .. }) => l == r,206(Gather { .. }, Gather { .. })207| (Filter { .. }, Filter { .. })208| (Ternary { .. }, Ternary { .. })209| (Len, Len)210| (Slice { .. }, Slice { .. }) => true,211(212Explode {213expr: _,214options: l_options,215},216Explode {217expr: _,218options: r_options,219},220) => l_options == r_options,221(222SortBy {223sort_options: l_sort_options,224..225},226SortBy {227sort_options: r_sort_options,228..229},230) => l_sort_options == r_sort_options,231(Agg(l), Agg(r)) => l.equal_nodes(r),232(233Function {234input: il,235function: fl,236options: ol,237},238Function {239input: ir,240function: fr,241options: or,242},243) => {244fl == fr && ol == or && {245let mut all_same_name = true;246for (l, r) in il.iter().zip(ir) {247all_same_name &= l.output_name() == r.output_name()248}249250all_same_name251}252},253(254AnonymousFunction {255function: l1,256options: l2,257fmt_str: l3,258input: _,259},260AnonymousFunction {261function: r1,262options: r2,263fmt_str: r3,264input: _,265},266) => {267l2 == r2 && l3 == r3 && {268use LazySerde as L;269match (l1, r1) {270// We only check the pointers, so this works for python271// functions that are on the same address.272(L::Deserialized(l0), L::Deserialized(r0)) => l0 == r0,273(L::Bytes(l0), L::Bytes(r0)) => l0 == r0,274(275L::Named {276name: l_name,277payload: l_payload,278value: l_value,279},280L::Named {281name: r_name,282payload: r_payload,283value: r_value,284},285) => l_name == r_name && l_payload == r_payload && l_value == r_value,286_ => false,287}288}289},290(BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,291_ => false,292}293}294}295296impl<'a> AExprArena<'a> {297pub fn new(node: Node, arena: &'a Arena<AExpr>) -> Self {298Self { node, arena }299}300pub fn to_aexpr(&self) -> &'a AExpr {301self.arena.get(self.node)302}303304// Check single node on equality305pub fn is_equal_single(&self, other: &Self) -> bool {306let self_ae = self.to_aexpr();307let other_ae = other.to_aexpr();308self_ae.is_equal_node(other_ae)309}310}311312impl PartialEq for AExprArena<'_> {313fn eq(&self, other: &Self) -> bool {314let mut scratch1 = unitvec![];315let mut scratch2 = unitvec![];316317scratch1.push(self.node);318scratch2.push(other.node);319320loop {321match (scratch1.pop(), scratch2.pop()) {322(Some(l), Some(r)) => {323let l = Self::new(l, self.arena);324let r = Self::new(r, other.arena);325326if !l.is_equal_single(&r) {327return false;328}329330l.to_aexpr().inputs_rev(&mut scratch1);331r.to_aexpr().inputs_rev(&mut scratch2);332},333(None, None) => return true,334_ => return false,335}336}337}338}339340impl TreeWalker for AexprNode {341type Arena = Arena<AExpr>;342fn apply_children<F: FnMut(&Self, &Self::Arena) -> PolarsResult<VisitRecursion>>(343&self,344op: &mut F,345arena: &Self::Arena,346) -> PolarsResult<VisitRecursion> {347let mut scratch = unitvec![];348349self.to_aexpr(arena).inputs_rev(&mut scratch);350for node in scratch.as_slice() {351let aenode = AexprNode::new(*node);352match op(&aenode, arena)? {353// let the recursion continue354VisitRecursion::Continue | VisitRecursion::Skip => {},355// early stop356VisitRecursion::Stop => return Ok(VisitRecursion::Stop),357}358}359Ok(VisitRecursion::Continue)360}361362fn map_children<F: FnMut(Self, &mut Self::Arena) -> PolarsResult<Self>>(363mut self,364op: &mut F,365arena: &mut Self::Arena,366) -> PolarsResult<Self> {367let mut scratch = unitvec![];368369let ae = arena.get(self.node).clone();370ae.inputs_rev(&mut scratch);371372// rewrite the nodes373for node in scratch.as_mut_slice() {374let aenode = AexprNode::new(*node);375*node = op(aenode, arena)?.node;376}377378scratch.as_mut_slice().reverse();379let ae = ae.replace_inputs(&scratch);380self.node = arena.add(ae);381Ok(self)382}383}384385386