Path: blob/main/crates/polars-stream/src/physical_plan/lower_expr.rs
6939 views
use std::sync::Arc;12use polars_core::chunked_array::cast::CastOptions;3use polars_core::frame::DataFrame;4use polars_core::prelude::{5DataType, Field, IDX_DTYPE, InitHashMaps, PlHashMap, PlHashSet, PlIndexMap,6};7use polars_core::schema::{Schema, SchemaExt};8use polars_error::PolarsResult;9use polars_expr::state::ExecutionState;10use polars_expr::{ExpressionConversionState, create_physical_expr};11use polars_ops::frame::{JoinArgs, JoinType};12use polars_ops::series::{RLE_LENGTH_COLUMN_NAME, RLE_VALUE_COLUMN_NAME};13use polars_plan::plans::AExpr;14use polars_plan::plans::expr_ir::{ExprIR, OutputName};15use polars_plan::prelude::*;16use polars_utils::arena::{Arena, Node};17use polars_utils::itertools::Itertools;18use polars_utils::pl_str::PlSmallStr;19use polars_utils::{unique_column_name, unitvec};20use slotmap::SlotMap;2122use super::fmt::fmt_exprs;23use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream, StreamingLowerIRContext};24use crate::physical_plan::lower_group_by::build_group_by_stream;25use crate::physical_plan::lower_ir::{build_filter_stream, build_row_idx_stream};2627type ExprNodeKey = Node;2829pub(crate) struct ExprCache {30is_elementwise: PlHashMap<Node, bool>,31is_input_independent: PlHashMap<Node, bool>,32is_length_preserving: PlHashMap<Node, bool>,33}3435impl ExprCache {36pub fn with_capacity(capacity: usize) -> Self {37Self {38is_elementwise: PlHashMap::with_capacity(capacity),39is_input_independent: PlHashMap::with_capacity(capacity),40is_length_preserving: PlHashMap::with_capacity(capacity),41}42}43}4445struct LowerExprContext<'a> {46prepare_visualization: bool,47expr_arena: &'a mut Arena<AExpr>,48phys_sm: &'a mut SlotMap<PhysNodeKey, PhysNode>,49cache: &'a mut ExprCache,50}5152impl<'a> From<LowerExprContext<'a>> for StreamingLowerIRContext {53fn from(value: LowerExprContext<'a>) -> Self {54Self {55prepare_visualization: value.prepare_visualization,56}57}58}59impl<'a> From<&LowerExprContext<'a>> for StreamingLowerIRContext {60fn from(value: &LowerExprContext<'a>) -> Self {61Self {62prepare_visualization: value.prepare_visualization,63}64}65}6667pub(crate) fn is_fake_elementwise_function(expr: &AExpr) -> bool {68// The in-memory engine treats ApplyList as elementwise but this is not actually69// the case. It doesn't cause any problems for the in-memory engine because of70// how it does the execution but it causes errors for new-streaming.7172// Some other functions are also marked as elementwise for filter pushdown73// but aren't actually elementwise (e.g. arguments aren't same length).74match expr {75AExpr::Function { function, .. } => {76use IRFunctionExpr as F;77match function {78#[cfg(feature = "is_in")]79F::Boolean(IRBooleanFunction::IsIn { .. }) => true,80#[cfg(feature = "replace")]81F::Replace | F::ReplaceStrict { .. } => true,82_ => false,83}84},85_ => false,86}87}8889pub(crate) fn is_elementwise_rec_cached(90expr_key: ExprNodeKey,91arena: &Arena<AExpr>,92cache: &mut ExprCache,93) -> bool {94if !cache.is_elementwise.contains_key(&expr_key) {95cache.is_elementwise.insert(96expr_key,97(|| {98let mut expr_key = expr_key;99let mut stack = unitvec![];100101loop {102let ae = arena.get(expr_key);103104if is_fake_elementwise_function(ae) {105return false;106}107108if !polars_plan::plans::is_elementwise(&mut stack, ae, arena) {109return false;110}111112let Some(next_key) = stack.pop() else {113break;114};115116expr_key = next_key;117}118119true120})(),121);122}123124*cache.is_elementwise.get(&expr_key).unwrap()125}126127#[recursive::recursive]128pub fn is_input_independent_rec(129expr_key: ExprNodeKey,130arena: &Arena<AExpr>,131cache: &mut PlHashMap<ExprNodeKey, bool>,132) -> bool {133if let Some(ret) = cache.get(&expr_key) {134return *ret;135}136137let ret = match arena.get(expr_key) {138AExpr::Explode { expr: inner, .. }139| AExpr::Cast {140expr: inner,141dtype: _,142options: _,143}144| AExpr::Sort {145expr: inner,146options: _,147} => is_input_independent_rec(*inner, arena, cache),148AExpr::Column(_) => false,149AExpr::Literal(_) => true,150AExpr::BinaryExpr { left, op: _, right } => {151is_input_independent_rec(*left, arena, cache)152&& is_input_independent_rec(*right, arena, cache)153},154AExpr::Gather {155expr,156idx,157returns_scalar: _,158} => {159is_input_independent_rec(*expr, arena, cache)160&& is_input_independent_rec(*idx, arena, cache)161},162AExpr::SortBy {163expr,164by,165sort_options: _,166} => {167is_input_independent_rec(*expr, arena, cache)168&& by169.iter()170.all(|expr| is_input_independent_rec(*expr, arena, cache))171},172AExpr::Filter { input, by } => {173is_input_independent_rec(*input, arena, cache)174&& is_input_independent_rec(*by, arena, cache)175},176AExpr::Agg(agg_expr) => match agg_expr.get_input() {177polars_plan::plans::NodeInputs::Leaf => true,178polars_plan::plans::NodeInputs::Single(expr) => {179is_input_independent_rec(expr, arena, cache)180},181polars_plan::plans::NodeInputs::Many(exprs) => exprs182.iter()183.all(|expr| is_input_independent_rec(*expr, arena, cache)),184},185AExpr::Ternary {186predicate,187truthy,188falsy,189} => {190is_input_independent_rec(*predicate, arena, cache)191&& is_input_independent_rec(*truthy, arena, cache)192&& is_input_independent_rec(*falsy, arena, cache)193},194AExpr::AnonymousFunction {195input,196function: _,197options: _,198fmt_str: _,199}200| AExpr::Function {201input,202function: _,203options: _,204} => input205.iter()206.all(|expr| is_input_independent_rec(expr.node(), arena, cache)),207AExpr::Eval {208expr,209evaluation: _,210variant: _,211} => is_input_independent_rec(*expr, arena, cache),212AExpr::Window {213function,214partition_by,215order_by,216options: _,217} => {218is_input_independent_rec(*function, arena, cache)219&& partition_by220.iter()221.all(|expr| is_input_independent_rec(*expr, arena, cache))222&& order_by223.iter()224.all(|(expr, _options)| is_input_independent_rec(*expr, arena, cache))225},226AExpr::Slice {227input,228offset,229length,230} => {231is_input_independent_rec(*input, arena, cache)232&& is_input_independent_rec(*offset, arena, cache)233&& is_input_independent_rec(*length, arena, cache)234},235AExpr::Len => false,236};237238cache.insert(expr_key, ret);239ret240}241242pub fn is_input_independent(243expr_key: ExprNodeKey,244expr_arena: &Arena<AExpr>,245cache: &mut ExprCache,246) -> bool {247is_input_independent_rec(expr_key, expr_arena, &mut cache.is_input_independent)248}249250fn is_input_independent_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool {251is_input_independent_rec(252expr_key,253ctx.expr_arena,254&mut ctx.cache.is_input_independent,255)256}257258fn build_input_independent_node_with_ctx(259exprs: &[ExprIR],260ctx: &mut LowerExprContext,261) -> PolarsResult<PhysNodeKey> {262let output_schema = compute_output_schema(&Schema::default(), exprs, ctx.expr_arena)?;263Ok(ctx.phys_sm.insert(PhysNode::new(264output_schema,265PhysNodeKind::InputIndependentSelect {266selectors: exprs.to_vec(),267},268)))269}270271#[recursive::recursive]272pub fn is_length_preserving_rec(273expr_key: ExprNodeKey,274arena: &Arena<AExpr>,275cache: &mut PlHashMap<ExprNodeKey, bool>,276) -> bool {277if let Some(ret) = cache.get(&expr_key) {278return *ret;279}280281let ret = match arena.get(expr_key) {282AExpr::Gather { .. }283| AExpr::Explode { .. }284| AExpr::Filter { .. }285| AExpr::Agg(_)286| AExpr::Slice { .. }287| AExpr::Len288| AExpr::Literal(_) => false,289290AExpr::Column(_) => true,291292AExpr::Cast {293expr: inner,294dtype: _,295options: _,296}297| AExpr::Sort {298expr: inner,299options: _,300}301| AExpr::SortBy {302expr: inner,303by: _,304sort_options: _,305} => is_length_preserving_rec(*inner, arena, cache),306307AExpr::BinaryExpr { left, op: _, right } => {308// As long as at least one input is length-preserving the other side309// should either broadcast or have the same length.310is_length_preserving_rec(*left, arena, cache)311|| is_length_preserving_rec(*right, arena, cache)312},313AExpr::Ternary {314predicate,315truthy,316falsy,317} => {318is_length_preserving_rec(*predicate, arena, cache)319|| is_length_preserving_rec(*truthy, arena, cache)320|| is_length_preserving_rec(*falsy, arena, cache)321},322AExpr::AnonymousFunction {323input,324function: _,325options,326fmt_str: _,327}328| AExpr::Function {329input,330function: _,331options,332} => {333// FIXME: actually inspect the functions? This is overly conservative.334options.is_length_preserving()335&& input336.iter()337.all(|expr| is_length_preserving_rec(expr.node(), arena, cache))338},339AExpr::Eval { .. } => true,340AExpr::Window {341function: _, // Actually shouldn't matter for window functions.342partition_by: _,343order_by: _,344options,345} => !matches!(options, WindowType::Over(WindowMapping::Explode)),346};347348cache.insert(expr_key, ret);349ret350}351352#[expect(dead_code)]353pub fn is_length_preserving(354expr_key: ExprNodeKey,355expr_arena: &Arena<AExpr>,356cache: &mut ExprCache,357) -> bool {358is_length_preserving_rec(expr_key, expr_arena, &mut cache.is_length_preserving)359}360361fn is_length_preserving_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool {362is_length_preserving_rec(363expr_key,364ctx.expr_arena,365&mut ctx.cache.is_length_preserving,366)367}368369fn build_fallback_node_with_ctx(370input: PhysStream,371exprs: &[ExprIR],372ctx: &mut LowerExprContext,373) -> PolarsResult<PhysNodeKey> {374// Pre-select only the columns that are needed for this fallback expression.375let input_schema = &ctx.phys_sm[input.node].output_schema;376let mut select_names: PlHashSet<_> = exprs377.iter()378.flat_map(|expr| polars_plan::utils::aexpr_to_leaf_names_iter(expr.node(), ctx.expr_arena))379.collect();380// To keep the length correct we have to ensure we select at least one381// column.382if select_names.is_empty() {383if let Some(name) = input_schema.iter_names().next() {384select_names.insert(name.clone());385}386}387let input_stream = if input_schema388.iter_names()389.any(|name| !select_names.contains(name.as_str()))390{391let select_exprs = select_names392.into_iter()393.map(|name| {394ExprIR::new(395ctx.expr_arena.add(AExpr::Column(name.clone())),396OutputName::ColumnLhs(name),397)398})399.collect_vec();400build_select_stream_with_ctx(input, &select_exprs, ctx)?401} else {402input403};404405let output_schema = schema_for_select(input_stream, exprs, ctx)?;406let mut conv_state = ExpressionConversionState::new(false);407let phys_exprs = exprs408.iter()409.map(|expr| {410create_physical_expr(411expr,412Context::Default,413ctx.expr_arena,414&ctx.phys_sm[input_stream.node].output_schema,415&mut conv_state,416)417})418.try_collect_vec()?;419let map = move |df| {420let exec_state = ExecutionState::new();421let columns = phys_exprs422.iter()423.map(|phys_expr| phys_expr.evaluate(&df, &exec_state))424.try_collect()?;425DataFrame::new_with_broadcast(columns)426};427428let format_str = ctx.prepare_visualization.then(|| {429let mut buffer = String::new();430buffer.push_str("SELECT [\n");431fmt_exprs(432&mut buffer,433exprs,434ctx.expr_arena,435super::fmt::FormatExprStyle::Select,436);437buffer.push(']');438buffer439});440let kind = PhysNodeKind::InMemoryMap {441input: input_stream,442map: Arc::new(map),443format_str,444};445Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, kind)))446}447448fn simplify_input_streams(449orig_input: PhysStream,450mut input_streams: PlHashSet<PhysStream>,451ctx: &mut LowerExprContext,452) -> PolarsResult<PlHashSet<PhysStream>> {453// Flatten nested zips (ensures the original input columns only occur once).454if input_streams.len() > 1 {455let mut flattened_input_streams = PlHashSet::with_capacity(input_streams.len());456for input_stream in input_streams {457if let PhysNodeKind::Zip {458inputs,459null_extend: false,460} = &ctx.phys_sm[input_stream.node].kind461{462flattened_input_streams.extend(inputs);463ctx.phys_sm.remove(input_stream.node);464} else {465flattened_input_streams.insert(input_stream);466}467}468input_streams = flattened_input_streams;469}470471// Merge reduce nodes that directly operate on the original input.472let mut combined_exprs = vec![];473input_streams = input_streams474.into_iter()475.filter(|input_stream| {476if let PhysNodeKind::Reduce {477input: inner,478exprs,479} = &ctx.phys_sm[input_stream.node].kind480{481if *inner == orig_input {482combined_exprs.extend(exprs.iter().cloned());483ctx.phys_sm.remove(input_stream.node);484return false;485}486}487true488})489.collect();490if !combined_exprs.is_empty() {491let output_schema = schema_for_select(orig_input, &combined_exprs, ctx)?;492let kind = PhysNodeKind::Reduce {493input: orig_input,494exprs: combined_exprs,495};496let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));497input_streams.insert(PhysStream::first(reduce_node_key));498}499500Ok(input_streams)501}502503// Assuming that agg_node is a single-input reduction, lowers its input recursively504// and returns a Reduce node as well a node corresponding to the column to select505// from the Reduce node for the aggregate.506fn lower_unary_reduce_node(507input: PhysStream,508agg_node: Node,509ctx: &mut LowerExprContext,510) -> PolarsResult<(PhysStream, Node)> {511let agg_aexpr = ctx.expr_arena.get(agg_node).clone();512let mut agg_input = Vec::with_capacity(1);513agg_aexpr.inputs_rev(&mut agg_input);514assert!(agg_input.len() == 1);515516let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &agg_input, ctx)?;517let trans_agg_node = ctx.expr_arena.add(agg_aexpr.replace_inputs(&trans_exprs));518519let out_name = unique_column_name();520let expr_ir = ExprIR::new(trans_agg_node, OutputName::Alias(out_name.clone()));521let output_schema = schema_for_select(trans_input, std::slice::from_ref(&expr_ir), ctx)?;522let kind = PhysNodeKind::Reduce {523input: trans_input,524exprs: vec![expr_ir],525};526527let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));528let reduce_stream = PhysStream::first(reduce_node_key);529let out_node = ctx.expr_arena.add(AExpr::Column(out_name));530Ok((reduce_stream, out_node))531}532533// In the recursive lowering we don't bother with named expressions at all, so534// we work directly with Nodes.535#[recursive::recursive]536fn lower_exprs_with_ctx(537input: PhysStream,538exprs: &[Node],539ctx: &mut LowerExprContext,540) -> PolarsResult<(PhysStream, Vec<Node>)> {541// We have to catch this case separately, in case all the input independent expressions are elementwise.542// TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion.543if exprs.iter().all(|e| is_input_independent_ctx(*e, ctx)) {544let expr_irs = exprs545.iter()546.map(|e| ExprIR::new(*e, OutputName::Alias(unique_column_name())))547.collect_vec();548let node = build_input_independent_node_with_ctx(&expr_irs, ctx)?;549let out_exprs = expr_irs550.iter()551.map(|e| ctx.expr_arena.add(AExpr::Column(e.output_name().clone())))552.collect();553return Ok((PhysStream::first(node), out_exprs));554}555556// Fallback expressions that can directly be applied to the original input.557let mut fallback_subset = Vec::new();558559// Streams containing the columns used for executing transformed expressions.560let mut input_streams = PlHashSet::new();561562// The final transformed expressions that will be selected from the zipped563// together transformed nodes.564let mut transformed_exprs = Vec::with_capacity(exprs.len());565566for expr in exprs.iter().copied() {567if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) {568if !is_input_independent_ctx(expr, ctx) {569input_streams.insert(input);570}571transformed_exprs.push(expr);572continue;573}574575match ctx.expr_arena.get(expr).clone() {576AExpr::Explode {577expr: inner,578skip_empty,579} => {580// While explode is streamable, it is not elementwise, so we581// have to transform it to a select node.582let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[inner], ctx)?;583let exploded_name = unique_column_name();584let trans_inner = ctx.expr_arena.add(AExpr::Explode {585expr: trans_exprs[0],586skip_empty,587});588let explode_expr =589ExprIR::new(trans_inner, OutputName::Alias(exploded_name.clone()));590let output_schema =591schema_for_select(trans_input, std::slice::from_ref(&explode_expr), ctx)?;592let node_kind = PhysNodeKind::Select {593input: trans_input,594selectors: vec![explode_expr.clone()],595extend_original: false,596};597let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));598input_streams.insert(PhysStream::first(node_key));599transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(exploded_name)));600},601AExpr::Column(_) => unreachable!("column should always be streamable"),602AExpr::Literal(_) => {603let out_name = unique_column_name();604let inner_expr = ExprIR::new(expr, OutputName::Alias(out_name.clone()));605let node_key = build_input_independent_node_with_ctx(&[inner_expr], ctx)?;606input_streams.insert(PhysStream::first(node_key));607transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));608},609610AExpr::Function {611input: ref inner_exprs,612function: IRFunctionExpr::Repeat,613options: _,614} => {615assert!(inner_exprs.len() == 2);616let out_name = unique_column_name();617let value_expr_ir = inner_exprs[0].with_alias(out_name.clone());618let repeats_expr_ir = inner_exprs[1].clone();619let value_stream = build_select_stream_with_ctx(input, &[value_expr_ir], ctx)?;620let repeats_stream = build_select_stream_with_ctx(input, &[repeats_expr_ir], ctx)?;621622let output_schema = ctx.phys_sm[value_stream.node].output_schema.clone();623let kind = PhysNodeKind::Repeat {624value: value_stream,625repeats: repeats_stream,626};627let repeat_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));628input_streams.insert(PhysStream::first(repeat_node_key));629transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));630},631632AExpr::Function {633input: ref inner_exprs,634function: IRFunctionExpr::ExtendConstant,635options: _,636} => {637assert!(inner_exprs.len() == 3);638let input_schema = &ctx.phys_sm[input.node].output_schema;639let out_name = unique_column_name();640let first_ir = inner_exprs[0].with_alias(out_name.clone());641let out_dtype = first_ir.dtype(input_schema, ctx.expr_arena)?;642let mut value_expr_ir = inner_exprs[1].with_alias(out_name.clone());643let repeats_expr_ir = inner_exprs[2].clone();644645// Cast the value if necessary.646if value_expr_ir.dtype(input_schema, ctx.expr_arena)? != out_dtype {647let cast_expr = AExpr::Cast {648expr: value_expr_ir.node(),649dtype: out_dtype.clone(),650options: CastOptions::NonStrict,651};652value_expr_ir = ExprIR::new(653ctx.expr_arena.add(cast_expr),654OutputName::Alias(out_name.clone()),655);656}657658let first_stream = build_select_stream_with_ctx(input, &[first_ir], ctx)?;659let value_stream = build_select_stream_with_ctx(input, &[value_expr_ir], ctx)?;660let repeats_stream = build_select_stream_with_ctx(input, &[repeats_expr_ir], ctx)?;661662let output_schema = ctx.phys_sm[first_stream.node].output_schema.clone();663let repeat_kind = PhysNodeKind::Repeat {664value: value_stream,665repeats: repeats_stream,666};667let repeat_node_key = ctx668.phys_sm669.insert(PhysNode::new(output_schema.clone(), repeat_kind));670671let concat_kind = PhysNodeKind::OrderedUnion {672inputs: vec![first_stream, PhysStream::first(repeat_node_key)],673};674let concat_node_key = ctx675.phys_sm676.insert(PhysNode::new(output_schema, concat_kind));677input_streams.insert(PhysStream::first(concat_node_key));678transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));679},680681AExpr::Function {682input: ref inner_exprs,683function: IRFunctionExpr::ConcatExpr(_rechunk),684options: _,685} => {686// We have to lower each expression separately as they might have different lengths.687let mut concat_streams = Vec::new();688let out_name = unique_column_name();689for inner_expr in inner_exprs {690let (trans_input, trans_expr) =691lower_exprs_with_ctx(input, &[inner_expr.node()], ctx)?;692let select_expr =693ExprIR::new(trans_expr[0], OutputName::Alias(out_name.clone()));694concat_streams.push(build_select_stream_with_ctx(695trans_input,696&[select_expr],697ctx,698)?);699}700701let output_schema = ctx.phys_sm[concat_streams[0].node].output_schema.clone();702let node_kind = PhysNodeKind::OrderedUnion {703inputs: concat_streams,704};705let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));706input_streams.insert(PhysStream::first(node_key));707transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));708},709710AExpr::Function {711input: ref inner_exprs,712function: IRFunctionExpr::Unique(maintain_order),713options: _,714} => {715assert!(inner_exprs.len() == 1);716// Lower to no-aggregate group-by with unique name.717let tmp_name = unique_column_name();718let (trans_input, trans_inner_exprs) =719lower_exprs_with_ctx(input, &[inner_exprs[0].node()], ctx)?;720let group_by_key_expr =721ExprIR::new(trans_inner_exprs[0], OutputName::Alias(tmp_name.clone()));722let group_by_output_schema =723schema_for_select(trans_input, std::slice::from_ref(&group_by_key_expr), ctx)?;724let group_by_stream = build_group_by_stream(725trans_input,726&[group_by_key_expr],727&[],728group_by_output_schema,729maintain_order,730Arc::new(GroupbyOptions::default()),731None,732ctx.expr_arena,733ctx.phys_sm,734ctx.cache,735StreamingLowerIRContext::from(&*ctx),736)?;737input_streams.insert(group_by_stream);738transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(tmp_name)));739},740741AExpr::Function {742input: ref inner_exprs,743function: IRFunctionExpr::UniqueCounts,744options: _,745} => {746// Transform:747// expr.unique_counts().alias(name)748// ->749// .select(expr.alias(name))750// .group_by(_ = name, maintain_order=True)751// .agg(name = pl.len())752// .select(name)753754assert_eq!(inner_exprs.len(), 1);755756let input_schema = &ctx.phys_sm[input.node].output_schema;757758let key_name = unique_column_name();759let tmp_count_name = unique_column_name();760761let input_expr = &inner_exprs[0];762let output_dtype = input_expr.dtype(input_schema, ctx.expr_arena)?.clone();763let group_by_output_schema = Arc::new(Schema::from_iter([764(key_name.clone(), output_dtype),765(tmp_count_name.clone(), IDX_DTYPE),766]));767768let keys = [input_expr.with_alias(key_name)];769let aggs = [ExprIR::new(770ctx.expr_arena.add(AExpr::Len),771OutputName::Alias(tmp_count_name.clone()),772)];773774let stream = build_group_by_stream(775input,776&keys,777&aggs,778group_by_output_schema,779true,780Default::default(),781None,782ctx.expr_arena,783ctx.phys_sm,784ctx.cache,785StreamingLowerIRContext {786prepare_visualization: ctx.prepare_visualization,787},788)?;789transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(tmp_count_name)));790input_streams.insert(stream);791},792AExpr::Function {793input: ref inner_exprs,794function:795IRFunctionExpr::ValueCounts {796sort: false,797parallel: _,798name: count_name,799normalize: false,800},801options: _,802} => {803// Transform:804// expr.value_counts(805// sort=False,806// parallel=_,807// name=count_name,808// normalize=False809// ).alias(name)810// ->811// .select(expr.alias(name))812// .group_by(name)813// .agg(count_name = pl.len())814// .select(pl.struct([name, count_name]))815816assert_eq!(inner_exprs.len(), 1);817818let input_schema = &ctx.phys_sm[input.node].output_schema;819820let tmp_value_name = unique_column_name();821let tmp_count_name = unique_column_name();822823let input_expr = &inner_exprs[0];824let output_field = input_expr.field(input_schema, ctx.expr_arena)?;825let group_by_output_schema = Arc::new(Schema::from_iter([826output_field.clone().with_name(tmp_value_name.clone()),827Field::new(tmp_count_name.clone(), IDX_DTYPE),828]));829830let keys = [input_expr.with_alias(tmp_value_name.clone())];831let aggs = [ExprIR::new(832ctx.expr_arena.add(AExpr::Len),833OutputName::Alias(tmp_count_name.clone()),834)];835836let stream = build_group_by_stream(837input,838&keys,839&aggs,840group_by_output_schema,841false,842Default::default(),843None,844ctx.expr_arena,845ctx.phys_sm,846ctx.cache,847StreamingLowerIRContext {848prepare_visualization: ctx.prepare_visualization,849},850)?;851852let value = ExprIR::new(853ctx.expr_arena.add(AExpr::Column(tmp_value_name)),854OutputName::Alias(output_field.name),855);856let count = ExprIR::new(857ctx.expr_arena.add(AExpr::Column(tmp_count_name)),858OutputName::Alias(count_name.clone()),859);860861transformed_exprs.push(862AExprBuilder::function(863vec![value, count],864IRFunctionExpr::AsStruct,865ctx.expr_arena,866)867.node(),868);869input_streams.insert(stream);870},871872AExpr::Function {873input: ref inner_exprs,874function: IRFunctionExpr::ArgUnique,875options: _,876} => {877// Transform:878// expr.arg_unique()879// ->880// .with_row_index(IDX)881// .group_by(expr)882// .agg(IDX = IDX.first())883// .select(IDX.sort())884885assert_eq!(inner_exprs.len(), 1);886887let expr_name = unique_column_name();888let idx_name = unique_column_name();889890let stream = build_select_stream_with_ctx(891input,892&[inner_exprs[0].with_alias(expr_name.clone())],893ctx,894)?;895896let mut group_by_output_schema =897ctx.phys_sm[stream.node].output_schema.as_ref().clone();898group_by_output_schema.insert(idx_name.clone(), IDX_DTYPE);899900let stream = build_row_idx_stream(stream, idx_name.clone(), None, ctx.phys_sm);901902let keys =903[AExprBuilder::col(expr_name.clone(), ctx.expr_arena).expr_ir(expr_name)];904let aggs = [AExprBuilder::col(idx_name.clone(), ctx.expr_arena)905.first(ctx.expr_arena)906.expr_ir(idx_name.clone())];907908let stream = build_group_by_stream(909stream,910&keys,911&aggs,912Arc::new(group_by_output_schema),913false,914Default::default(),915None,916ctx.expr_arena,917ctx.phys_sm,918ctx.cache,919StreamingLowerIRContext {920prepare_visualization: ctx.prepare_visualization,921},922)?;923924let expr = AExprBuilder::col(idx_name.clone(), ctx.expr_arena)925.sort(Default::default(), ctx.expr_arena)926.expr_ir(idx_name.clone());927let stream = build_select_stream_with_ctx(stream, &[expr], ctx)?;928929transformed_exprs.push(AExprBuilder::col(idx_name.clone(), ctx.expr_arena).node());930input_streams.insert(stream);931},932933#[cfg(feature = "is_in")]934AExpr::Function {935input: ref inner_exprs,936function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { nulls_equal }),937options: _,938} if is_scalar_ae(inner_exprs[1].node(), ctx.expr_arena) => {939// Translate left and right side separately (they could have different lengths).940let left_on_name = unique_column_name();941let right_on_name = unique_column_name();942let (trans_input_left, trans_expr_left) =943lower_exprs_with_ctx(input, &[inner_exprs[0].node()], ctx)?;944let right_expr_exploded_node = match ctx.expr_arena.get(inner_exprs[1].node()) {945// expr.implode().explode() ~= expr (and avoids rechunking)946AExpr::Agg(IRAggExpr::Implode(n)) => *n,947_ => ctx.expr_arena.add(AExpr::Explode {948expr: inner_exprs[1].node(),949skip_empty: true,950}),951};952let (trans_input_right, trans_expr_right) =953lower_exprs_with_ctx(input, &[right_expr_exploded_node], ctx)?;954955// We have to ensure the left input has the right name for the semi-anti-join to956// generate the correct output name.957let left_col_expr = ctx.expr_arena.add(AExpr::Column(left_on_name.clone()));958let left_select_stream = build_select_stream_with_ctx(959trans_input_left,960&[ExprIR::new(961trans_expr_left[0],962OutputName::Alias(left_on_name.clone()),963)],964ctx,965)?;966967let node_kind = PhysNodeKind::SemiAntiJoin {968input_left: left_select_stream,969input_right: trans_input_right,970left_on: vec![ExprIR::new(971left_col_expr,972OutputName::Alias(left_on_name.clone()),973)],974right_on: vec![ExprIR::new(975trans_expr_right[0],976OutputName::Alias(right_on_name),977)],978args: JoinArgs {979how: JoinType::Semi,980validation: Default::default(),981suffix: None,982slice: None,983nulls_equal,984coalesce: Default::default(),985maintain_order: Default::default(),986},987output_bool: true,988};989990// SemiAntiJoin with output_bool returns a column with the same name as the first991// input column.992let output_schema = Schema::from_iter([(left_on_name.clone(), DataType::Boolean)]);993let node_key = ctx994.phys_sm995.insert(PhysNode::new(Arc::new(output_schema), node_kind));996input_streams.insert(PhysStream::first(node_key));997transformed_exprs.push(left_col_expr);998},9991000#[cfg(feature = "cum_agg")]1001AExpr::Function {1002input: ref inner_exprs,1003function:1004ref function @ (IRFunctionExpr::CumMin { reverse }1005| IRFunctionExpr::CumMax { reverse }1006| IRFunctionExpr::CumSum { reverse }1007| IRFunctionExpr::CumCount { reverse }1008| IRFunctionExpr::CumProd { reverse }),1009options: _,1010} if !reverse => {1011use crate::nodes::cum_agg::CumAggKind;10121013assert_eq!(inner_exprs.len(), 1);10141015let input_schema = &ctx.phys_sm[input.node].output_schema;10161017let value_key = unique_column_name();1018let value_dtype = inner_exprs[0].dtype(input_schema, ctx.expr_arena)?;10191020let input = build_select_stream_with_ctx(1021input,1022&[inner_exprs[0].with_alias(value_key.clone())],1023ctx,1024)?;1025let kind = match function {1026IRFunctionExpr::CumMin { .. } => CumAggKind::Min,1027IRFunctionExpr::CumMax { .. } => CumAggKind::Max,1028IRFunctionExpr::CumSum { .. } => CumAggKind::Sum,1029IRFunctionExpr::CumCount { .. } => CumAggKind::Count,1030IRFunctionExpr::CumProd { .. } => CumAggKind::Prod,1031_ => unreachable!(),1032};1033let node_kind = PhysNodeKind::CumAgg { input, kind };10341035let output_schema = Schema::from_iter([(value_key.clone(), value_dtype.clone())]);1036let node_key = ctx1037.phys_sm1038.insert(PhysNode::new(Arc::new(output_schema), node_kind));1039input_streams.insert(PhysStream::first(node_key));1040transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(value_key)));1041},10421043#[cfg(feature = "diff")]1044AExpr::Function {1045input: ref inner_exprs,1046function: IRFunctionExpr::Diff(null_behavior),1047options: _,1048} => {1049use polars_core::scalar::Scalar;1050use polars_core::series::ops::NullBehavior;10511052assert_eq!(inner_exprs.len(), 2);10531054// Transform:1055// expr.diff(offset, "ignore")1056// ->1057// expr - expr.shift(offset)10581059let base_expr_ir = &inner_exprs[0];1060let base_dtype =1061base_expr_ir.dtype(&ctx.phys_sm[input.node].output_schema, ctx.expr_arena)?;1062let offset_expr_ir = &inner_exprs[1];1063let offset_dtype =1064offset_expr_ir.dtype(&ctx.phys_sm[input.node].output_schema, ctx.expr_arena)?;10651066let mut base = AExprBuilder::new_from_node(base_expr_ir.node());1067let cast_dtype = match base_dtype {1068DataType::UInt8 => Some(DataType::Int16),1069DataType::UInt16 => Some(DataType::Int32),1070DataType::UInt32 | DataType::UInt64 => Some(DataType::Int64),1071_ => None,1072};1073if let Some(dtype) = cast_dtype {1074base = base.cast(dtype, ctx.expr_arena);1075}10761077let mut offset = AExprBuilder::new_from_node(offset_expr_ir.node());1078if offset_dtype != &DataType::Int64 {1079offset = offset.cast(DataType::Int64, ctx.expr_arena);1080}10811082let shifted = base.shift(offset.node(), ctx.expr_arena);1083let mut output = base.minus(shifted.node(), ctx.expr_arena);10841085if null_behavior == NullBehavior::Drop {1086// Without the column size, slice can only remove leading nulls.1087// So if the offset was negative, the nulls appeared at the end of the column.1088// In that case, shift the column forward to move the nulls back to the front.1089let zero_literal =1090AExprBuilder::lit(LiteralValue::new_idxsize(0), ctx.expr_arena);1091let offset_neg = offset.negate(ctx.expr_arena);1092let offset_if_negative = AExprBuilder::function(1093vec![offset_neg.expr_ir_unnamed(), zero_literal.expr_ir_unnamed()],1094IRFunctionExpr::MaxHorizontal,1095ctx.expr_arena,1096);1097output = output.shift(offset_if_negative, ctx.expr_arena);10981099// Remove the nulls that were introduced by the shift1100let offset_abs = offset.abs(ctx.expr_arena);1101let null_literal = AExprBuilder::lit(1102LiteralValue::Scalar(Scalar::null(DataType::Int64)),1103ctx.expr_arena,1104);1105output = output.slice(offset_abs, null_literal, ctx.expr_arena);1106}11071108let (stream, nodes) = lower_exprs_with_ctx(input, &[output.node()], ctx)?;1109input_streams.insert(stream);1110transformed_exprs.extend(nodes);1111},11121113AExpr::Function {1114input: ref inner_exprs,1115function: IRFunctionExpr::RLE,1116options: _,1117} => {1118assert_eq!(inner_exprs.len(), 1);11191120let input_schema = &ctx.phys_sm[input.node].output_schema;11211122let value_key = unique_column_name();1123let value_dtype = inner_exprs[0].dtype(input_schema, ctx.expr_arena)?;11241125let input = build_select_stream_with_ctx(1126input,1127&[inner_exprs[0].with_alias(value_key.clone())],1128ctx,1129)?;1130let node_kind = PhysNodeKind::Rle(input);11311132let output_schema = Schema::from_iter([(1133value_key.clone(),1134DataType::Struct(vec![1135Field::new(1136PlSmallStr::from_static(RLE_VALUE_COLUMN_NAME),1137value_dtype.clone(),1138),1139Field::new(PlSmallStr::from_static(RLE_LENGTH_COLUMN_NAME), IDX_DTYPE),1140]),1141)]);1142let node_key = ctx1143.phys_sm1144.insert(PhysNode::new(Arc::new(output_schema), node_kind));1145input_streams.insert(PhysStream::first(node_key));1146transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(value_key)));1147},11481149AExpr::Function {1150input: ref inner_exprs,1151function: IRFunctionExpr::RLEID,1152options: _,1153} => {1154assert_eq!(inner_exprs.len(), 1);11551156let value_key = unique_column_name();11571158let input = build_select_stream_with_ctx(1159input,1160&[inner_exprs[0].with_alias(value_key.clone())],1161ctx,1162)?;1163let node_kind = PhysNodeKind::RleId(input);11641165let output_schema = Schema::from_iter([(value_key.clone(), IDX_DTYPE)]);1166let node_key = ctx1167.phys_sm1168.insert(PhysNode::new(Arc::new(output_schema), node_kind));1169input_streams.insert(PhysStream::first(node_key));1170transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(value_key.clone())));1171},11721173AExpr::Function {1174input: ref inner_exprs,1175function: ref function @ (IRFunctionExpr::PeakMin | IRFunctionExpr::PeakMax),1176options: _,1177} => {1178assert_eq!(inner_exprs.len(), 1);11791180let value_key = unique_column_name();11811182let input = build_select_stream_with_ctx(1183input,1184&[inner_exprs[0].with_alias(value_key.clone())],1185ctx,1186)?;1187let is_peak_max = matches!(function, IRFunctionExpr::PeakMax);1188let node_kind = PhysNodeKind::PeakMinMax { input, is_peak_max };11891190let output_schema = Schema::from_iter([(value_key.clone(), DataType::Boolean)]);1191let node_key = ctx1192.phys_sm1193.insert(PhysNode::new(Arc::new(output_schema), node_kind));1194input_streams.insert(PhysStream::first(node_key));1195transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(value_key.clone())));1196},11971198// pl.row_index() maps to this.1199#[cfg(feature = "range")]1200AExpr::Function {1201input: ref inner_exprs,1202function: IRFunctionExpr::Range(IRRangeFunction::IntRange { step: 1, dtype }),1203options: _,1204} if {1205let start_is_zero = match ctx.expr_arena.get(inner_exprs[0].node()) {1206AExpr::Literal(lit) => lit.extract_usize().ok() == Some(0),1207_ => false,1208};1209let stop_is_len = matches!(ctx.expr_arena.get(inner_exprs[1].node()), AExpr::Len);12101211dtype == DataType::IDX_DTYPE && start_is_zero && stop_is_len1212} =>1213{1214let out_name = unique_column_name();1215let row_idx_col_aexpr = ctx.expr_arena.add(AExpr::Column(out_name.clone()));1216let row_idx_col_expr_ir =1217ExprIR::new(row_idx_col_aexpr, OutputName::ColumnLhs(out_name.clone()));1218let row_idx_stream = build_select_stream_with_ctx(1219build_row_idx_stream(input, out_name, None, ctx.phys_sm),1220&[row_idx_col_expr_ir],1221ctx,1222)?;1223input_streams.insert(row_idx_stream);1224transformed_exprs.push(row_idx_col_aexpr);1225},12261227#[cfg(feature = "range")]1228AExpr::Function {1229input: ref inner_exprs,1230function: IRFunctionExpr::Range(IRRangeFunction::IntRange { step: 1, dtype }),1231options: _,1232} if {1233let start_is_zero = match ctx.expr_arena.get(inner_exprs[0].node()) {1234AExpr::Literal(lit) => lit.extract_usize().ok() == Some(0),1235_ => false,1236};1237let stop_is_count = matches!(1238ctx.expr_arena.get(inner_exprs[1].node()),1239AExpr::Agg(IRAggExpr::Count { .. })1240);12411242start_is_zero && stop_is_count1243} =>1244{1245let AExpr::Agg(IRAggExpr::Count {1246input: input_expr,1247include_nulls,1248}) = ctx.expr_arena.get(inner_exprs[1].node())1249else {1250unreachable!();1251};1252let (input_expr, include_nulls) = (*input_expr, *include_nulls);12531254let out_name = unique_column_name();1255let mut row_idx_col_aexpr = ctx.expr_arena.add(AExpr::Column(out_name.clone()));1256if dtype != IDX_DTYPE {1257row_idx_col_aexpr = AExprBuilder::new_from_node(row_idx_col_aexpr)1258.cast(dtype, ctx.expr_arena)1259.node();1260}1261let row_idx_col_expr_ir =1262ExprIR::new(row_idx_col_aexpr, OutputName::ColumnLhs(out_name.clone()));12631264let mut input_expr = AExprBuilder::new_from_node(input_expr);1265if !include_nulls {1266input_expr = input_expr.drop_nulls(ctx.expr_arena);1267}1268let input_expr = input_expr.expr_ir_retain_name(ctx.expr_arena);12691270let row_idx_stream = build_select_stream_with_ctx(1271build_row_idx_stream(1272build_select_stream_with_ctx(input, &[input_expr], ctx)?,1273out_name,1274None,1275ctx.phys_sm,1276),1277&[row_idx_col_expr_ir],1278ctx,1279)?;1280input_streams.insert(row_idx_stream);1281transformed_exprs.push(row_idx_col_aexpr);1282},12831284// Lower arbitrary elementwise functions.1285ref node @ AExpr::Function {1286input: ref inner_exprs,1287options,1288..1289}1290| ref node @ AExpr::AnonymousFunction {1291input: ref inner_exprs,1292options,1293..1294} if options.is_elementwise() && !is_fake_elementwise_function(node) => {1295let inner_nodes = inner_exprs.iter().map(|expr| expr.node()).collect_vec();1296let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &inner_nodes, ctx)?;12971298// The function may be sensitive to names (e.g. pl.struct), so we restore them.1299let new_input = trans_exprs1300.into_iter()1301.zip(inner_exprs)1302.map(|(trans, orig)| {1303ExprIR::new(trans, OutputName::Alias(orig.output_name().clone()))1304})1305.collect_vec();1306let mut new_node = node.clone();1307match &mut new_node {1308AExpr::Function { input, .. } | AExpr::AnonymousFunction { input, .. } => {1309*input = new_input;1310},1311_ => unreachable!(),1312}1313input_streams.insert(trans_input);1314transformed_exprs.push(ctx.expr_arena.add(new_node));1315},13161317// Lower arbitrary row-separable functions.1318ref node @ AExpr::Function {1319input: ref inner_exprs,1320ref function,1321options,1322} if options.is_row_separable() && !is_fake_elementwise_function(node) => {1323// While these functions are streamable, they are not elementwise, so we1324// have to transform them to a select node.1325let inner_nodes = inner_exprs.iter().map(|x| x.node()).collect_vec();1326let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &inner_nodes, ctx)?;1327let out_name = unique_column_name();1328let trans_inner = ctx.expr_arena.add(AExpr::Function {1329input: trans_exprs1330.iter()1331.map(|node| ExprIR::from_node(*node, ctx.expr_arena))1332.collect(),1333function: function.clone(),1334options,1335});1336let func_expr = ExprIR::new(trans_inner, OutputName::Alias(out_name.clone()));1337let output_schema =1338schema_for_select(trans_input, std::slice::from_ref(&func_expr), ctx)?;1339let node_kind = PhysNodeKind::Select {1340input: trans_input,1341selectors: vec![func_expr.clone()],1342extend_original: false,1343};1344let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));1345input_streams.insert(PhysStream::first(node_key));1346transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1347},13481349AExpr::BinaryExpr { left, op, right } => {1350let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[left, right], ctx)?;1351let bin_expr = AExpr::BinaryExpr {1352left: trans_exprs[0],1353op,1354right: trans_exprs[1],1355};1356input_streams.insert(trans_input);1357transformed_exprs.push(ctx.expr_arena.add(bin_expr));1358},1359AExpr::Eval {1360expr: inner,1361evaluation,1362variant,1363} => match variant {1364EvalVariant::List => {1365let (trans_input, trans_expr) = lower_exprs_with_ctx(input, &[inner], ctx)?;1366let eval_expr = AExpr::Eval {1367expr: trans_expr[0],1368evaluation,1369variant,1370};1371input_streams.insert(trans_input);1372transformed_exprs.push(ctx.expr_arena.add(eval_expr));1373},1374EvalVariant::Cumulative { .. } => {1375// Cumulative is not elementwise, this would need a special node.1376let out_name = unique_column_name();1377fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone())));1378transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1379},1380},1381AExpr::Ternary {1382predicate,1383truthy,1384falsy,1385} => {1386let (trans_input, trans_exprs) =1387lower_exprs_with_ctx(input, &[predicate, truthy, falsy], ctx)?;1388let tern_expr = AExpr::Ternary {1389predicate: trans_exprs[0],1390truthy: trans_exprs[1],1391falsy: trans_exprs[2],1392};1393input_streams.insert(trans_input);1394transformed_exprs.push(ctx.expr_arena.add(tern_expr));1395},1396AExpr::Cast {1397expr: inner,1398dtype,1399options,1400} => {1401let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[inner], ctx)?;1402input_streams.insert(trans_input);1403transformed_exprs.push(ctx.expr_arena.add(AExpr::Cast {1404expr: trans_exprs[0],1405dtype,1406options,1407}));1408},1409AExpr::Sort {1410expr: inner,1411options,1412} => {1413// As we'll refer to the sorted column twice, ensure the inner1414// expr is available as a column by selecting first.1415let sorted_name = unique_column_name();1416let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(sorted_name.clone()));1417let select_stream =1418build_select_stream_with_ctx(input, std::slice::from_ref(&inner_expr_ir), ctx)?;1419let col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone()));1420let kind = PhysNodeKind::Sort {1421input: select_stream,1422by_column: vec![ExprIR::new(col_expr, OutputName::Alias(sorted_name))],1423slice: None,1424sort_options: (&options).into(),1425};1426let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone();1427let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1428input_streams.insert(PhysStream::first(node_key));1429transformed_exprs.push(col_expr);1430},14311432AExpr::SortBy {1433expr: inner,1434by,1435sort_options,1436} => {1437// Select our inputs (if we don't do this we'll waste time sorting irrelevant columns).1438let sorted_name = unique_column_name();1439let by_names = by.iter().map(|_| unique_column_name()).collect_vec();1440let all_inner_expr_irs = [(&sorted_name, inner)]1441.into_iter()1442.chain(by_names.iter().zip(by.iter().copied()))1443.map(|(name, inner)| ExprIR::new(inner, OutputName::Alias(name.clone())))1444.collect_vec();1445let select_stream = build_select_stream_with_ctx(input, &all_inner_expr_irs, ctx)?;14461447// Sort the inputs.1448let kind = PhysNodeKind::Sort {1449input: select_stream,1450by_column: by_names1451.into_iter()1452.map(|name| {1453ExprIR::new(1454ctx.expr_arena.add(AExpr::Column(name.clone())),1455OutputName::Alias(name),1456)1457})1458.collect(),1459slice: None,1460sort_options,1461};1462let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone();1463let sort_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));14641465let sorted_col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone()));1466input_streams.insert(PhysStream::first(sort_node_key));1467transformed_exprs.push(sorted_col_expr);1468},14691470#[cfg(feature = "top_k")]1471AExpr::Function {1472input: inner_exprs,1473function: function @ (IRFunctionExpr::TopK { .. } | IRFunctionExpr::TopKBy { .. }),1474options: _,1475} => {1476// Select our inputs.1477let by = &inner_exprs[2..];1478let out_name = unique_column_name();1479let by_names = by.iter().map(|_| unique_column_name()).collect_vec();1480let data_irs = [(&out_name, &inner_exprs[0])]1481.into_iter()1482.chain(by_names.iter().zip(by.iter()))1483.map(|(name, inner)| ExprIR::new(inner.node(), OutputName::Alias(name.clone())))1484.collect_vec();1485let data_stream = build_select_stream_with_ctx(input, &data_irs, ctx)?;1486let k_stream = build_select_stream_with_ctx(input, &inner_exprs[1..2], ctx)?;14871488// Create 'by' column expressions.1489let out_col_node = ctx.expr_arena.add(AExpr::Column(out_name.clone()));1490let out_col_expr = ExprIR::new(out_col_node, OutputName::Alias(out_name));1491let (by_column, reverse) = match function {1492IRFunctionExpr::TopK { descending } => {1493(vec![out_col_expr.clone()], vec![descending])1494},1495IRFunctionExpr::TopKBy {1496descending: reverse,1497} => {1498let by_column = by_names1499.into_iter()1500.map(|name| {1501ExprIR::new(1502ctx.expr_arena.add(AExpr::Column(name.clone())),1503OutputName::Alias(name),1504)1505})1506.collect();1507(by_column, reverse.clone())1508},1509_ => unreachable!(),1510};15111512let kind = PhysNodeKind::TopK {1513input: data_stream,1514k: k_stream,1515nulls_last: vec![true; by_column.len()],1516reverse,1517by_column,1518};1519let output_schema = ctx.phys_sm[data_stream.node].output_schema.clone();1520let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1521input_streams.insert(PhysStream::first(node_key));1522transformed_exprs.push(out_col_node);1523},15241525AExpr::Filter { input: inner, by } => {1526// Select our inputs (if we don't do this we'll waste time filtering irrelevant columns).1527let out_name = unique_column_name();1528let by_name = unique_column_name();1529let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(out_name.clone()));1530let by_expr_ir = ExprIR::new(by, OutputName::Alias(by_name.clone()));1531let select_stream =1532build_select_stream_with_ctx(input, &[inner_expr_ir, by_expr_ir], ctx)?;15331534// Add a filter node.1535let predicate = ExprIR::new(1536ctx.expr_arena.add(AExpr::Column(by_name.clone())),1537OutputName::Alias(by_name),1538);1539let kind = PhysNodeKind::Filter {1540input: select_stream,1541predicate,1542};1543let output_schema = ctx.phys_sm[select_stream.node].output_schema.clone();1544let filter_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1545input_streams.insert(PhysStream::first(filter_node_key));1546transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1547},15481549// Aggregates.1550AExpr::Agg(agg) => match agg {1551// Change agg mutably so we can share the codepath for all of these.1552IRAggExpr::Min { .. }1553| IRAggExpr::Max { .. }1554| IRAggExpr::First(_)1555| IRAggExpr::Last(_)1556| IRAggExpr::Sum(_)1557| IRAggExpr::Mean(_)1558| IRAggExpr::Var { .. }1559| IRAggExpr::Std { .. }1560| IRAggExpr::Count { .. } => {1561let (trans_stream, trans_expr) = lower_unary_reduce_node(input, expr, ctx)?;1562input_streams.insert(trans_stream);1563transformed_exprs.push(trans_expr);1564},1565IRAggExpr::NUnique(inner) => {1566// Lower to no-aggregate group-by with unique name feeding into len aggregate.1567let tmp_name = unique_column_name();1568let (trans_input, trans_inner_exprs) =1569lower_exprs_with_ctx(input, &[inner], ctx)?;1570let group_by_key_expr =1571ExprIR::new(trans_inner_exprs[0], OutputName::Alias(tmp_name.clone()));1572let group_by_output_schema = schema_for_select(1573trans_input,1574std::slice::from_ref(&group_by_key_expr),1575ctx,1576)?;1577let group_by_stream = build_group_by_stream(1578trans_input,1579&[group_by_key_expr],1580&[],1581group_by_output_schema,1582false,1583Arc::new(GroupbyOptions::default()),1584None,1585ctx.expr_arena,1586ctx.phys_sm,1587ctx.cache,1588StreamingLowerIRContext::from(&*ctx),1589)?;15901591let len_node = ctx.expr_arena.add(AExpr::Len);1592let len_expr_ir = ExprIR::new(len_node, OutputName::Alias(tmp_name.clone()));1593let output_schema = schema_for_select(1594group_by_stream,1595std::slice::from_ref(&len_expr_ir),1596ctx,1597)?;1598let kind = PhysNodeKind::Reduce {1599input: group_by_stream,1600exprs: vec![len_expr_ir],1601};16021603let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1604input_streams.insert(PhysStream::first(reduce_node_key));1605transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(tmp_name)));1606},1607IRAggExpr::Median(_)1608| IRAggExpr::Implode(_)1609| IRAggExpr::Quantile { .. }1610| IRAggExpr::AggGroups(_) => {1611let out_name = unique_column_name();1612fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone())));1613transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1614},1615},16161617#[cfg(feature = "bitwise")]1618AExpr::Function {1619function:1620IRFunctionExpr::Bitwise(1621IRBitwiseFunction::And | IRBitwiseFunction::Or | IRBitwiseFunction::Xor,1622),1623..1624} => {1625let (trans_stream, trans_expr) = lower_unary_reduce_node(input, expr, ctx)?;1626input_streams.insert(trans_stream);1627transformed_exprs.push(trans_expr);1628},16291630AExpr::Function {1631function:1632IRFunctionExpr::Boolean(1633IRBooleanFunction::Any { .. } | IRBooleanFunction::All { .. },1634),1635..1636} => {1637let (trans_stream, trans_expr) = lower_unary_reduce_node(input, expr, ctx)?;1638input_streams.insert(trans_stream);1639transformed_exprs.push(trans_expr);1640},16411642// Length-based expressions.1643AExpr::Len => {1644let out_name = unique_column_name();1645let expr_ir = ExprIR::new(expr, OutputName::Alias(out_name.clone()));1646let output_schema = schema_for_select(input, std::slice::from_ref(&expr_ir), ctx)?;1647let kind = PhysNodeKind::Reduce {1648input,1649exprs: vec![expr_ir],1650};1651let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1652input_streams.insert(PhysStream::first(reduce_node_key));1653transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1654},16551656AExpr::Function {1657input: ref inner_exprs,1658function: IRFunctionExpr::ArgWhere,1659options: _,1660} => {1661// pl.arg_where(expr)1662//1663// ->1664// .select(predicate_name = expr)1665// .with_row_index(out_name)1666// .filter(predicate_name)1667// .select(out_name)1668let out_name = unique_column_name();1669let predicate_name = unique_column_name();1670let predicate = build_select_stream_with_ctx(1671input,1672&[inner_exprs[0].with_alias(predicate_name.clone())],1673ctx,1674)?;1675let row_index =1676build_row_idx_stream(predicate, out_name.clone(), None, ctx.phys_sm);16771678let filter_stream = build_filter_stream(1679row_index,1680AExprBuilder::col(predicate_name.clone(), ctx.expr_arena)1681.expr_ir(predicate_name),1682ctx.expr_arena,1683ctx.phys_sm,1684ctx.cache,1685StreamingLowerIRContext {1686prepare_visualization: ctx.prepare_visualization,1687},1688)?;1689input_streams.insert(filter_stream);1690transformed_exprs.push(AExprBuilder::col(out_name.clone(), ctx.expr_arena).node());1691},16921693AExpr::Slice {1694input: inner,1695offset,1696length,1697} => {1698let out_name = unique_column_name();1699let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(out_name.clone()));1700let offset_expr_ir = ExprIR::from_node(offset, ctx.expr_arena);1701let length_expr_ir = ExprIR::from_node(length, ctx.expr_arena);1702let input_stream = build_select_stream_with_ctx(input, &[inner_expr_ir], ctx)?;1703let offset_stream = build_select_stream_with_ctx(input, &[offset_expr_ir], ctx)?;1704let length_stream = build_select_stream_with_ctx(input, &[length_expr_ir], ctx)?;17051706let output_schema = ctx.phys_sm[input_stream.node].output_schema.clone();1707let kind = PhysNodeKind::DynamicSlice {1708input: input_stream,1709offset: offset_stream,1710length: length_stream,1711};1712let slice_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind));1713input_streams.insert(PhysStream::first(slice_node_key));1714transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1715},17161717AExpr::Function {1718input: ref inner_exprs,1719function: func @ (IRFunctionExpr::Shift | IRFunctionExpr::ShiftAndFill),1720options: _,1721} => {1722let out_name = unique_column_name();1723let data_col_expr = inner_exprs[0].with_alias(out_name.clone());1724let trans_data_column = build_select_stream_with_ctx(input, &[data_col_expr], ctx)?;1725let trans_offset =1726build_select_stream_with_ctx(input, &[inner_exprs[1].clone()], ctx)?;17271728let trans_fill = if func == IRFunctionExpr::ShiftAndFill {1729let fill_expr = inner_exprs[2].with_alias(out_name.clone());1730Some(build_select_stream_with_ctx(input, &[fill_expr], ctx)?)1731} else {1732None1733};17341735let output_schema = ctx.phys_sm[trans_data_column.node].output_schema.clone();1736let node_key = ctx.phys_sm.insert(PhysNode::new(1737output_schema,1738PhysNodeKind::Shift {1739input: trans_data_column,1740offset: trans_offset,1741fill: trans_fill,1742},1743));17441745input_streams.insert(PhysStream::first(node_key));1746transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1747},17481749AExpr::AnonymousFunction { .. }1750| AExpr::Function { .. }1751| AExpr::Window { .. }1752| AExpr::Gather { .. } => {1753let out_name = unique_column_name();1754fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone())));1755transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));1756},1757}1758}17591760if !fallback_subset.is_empty() {1761let fallback_node = build_fallback_node_with_ctx(input, &fallback_subset, ctx)?;1762input_streams.insert(PhysStream::first(fallback_node));1763}17641765// Simplify the input nodes (also ensures the original input only occurs1766// once in the zip).1767input_streams = simplify_input_streams(input, input_streams, ctx)?;17681769if input_streams.len() == 1 {1770// No need for any multiplexing/zipping, can directly execute.1771return Ok((input_streams.into_iter().next().unwrap(), transformed_exprs));1772}17731774let zip_inputs = input_streams.into_iter().collect_vec();1775let output_schema = zip_inputs1776.iter()1777.flat_map(|stream| ctx.phys_sm[stream.node].output_schema.iter_fields())1778.collect();1779let zip_kind = PhysNodeKind::Zip {1780inputs: zip_inputs,1781null_extend: false,1782};1783let zip_node = ctx1784.phys_sm1785.insert(PhysNode::new(Arc::new(output_schema), zip_kind));17861787Ok((PhysStream::first(zip_node), transformed_exprs))1788}17891790/// Computes the schema that selecting the given expressions on the input schema1791/// would result in.1792pub fn compute_output_schema(1793input_schema: &Schema,1794exprs: &[ExprIR],1795expr_arena: &Arena<AExpr>,1796) -> PolarsResult<Arc<Schema>> {1797let output_schema: Schema = exprs1798.iter()1799.map(|e| {1800let name = e.output_name().clone();1801let dtype = e1802.dtype(input_schema, expr_arena)?1803.clone()1804.materialize_unknown(true)1805.unwrap();1806PolarsResult::Ok(Field::new(name, dtype))1807})1808.try_collect()?;1809Ok(Arc::new(output_schema))1810}18111812/// Computes the schema that selecting the given expressions on the input node1813/// would result in.1814fn schema_for_select(1815input: PhysStream,1816exprs: &[ExprIR],1817ctx: &mut LowerExprContext,1818) -> PolarsResult<Arc<Schema>> {1819let input_schema = &ctx.phys_sm[input.node].output_schema;1820compute_output_schema(input_schema, exprs, ctx.expr_arena)1821}18221823fn build_select_stream_with_ctx(1824input: PhysStream,1825exprs: &[ExprIR],1826ctx: &mut LowerExprContext,1827) -> PolarsResult<PhysStream> {1828if exprs1829.iter()1830.all(|e| is_input_independent_ctx(e.node(), ctx))1831{1832return Ok(PhysStream::first(build_input_independent_node_with_ctx(1833exprs, ctx,1834)?));1835}18361837// Are we only selecting simple columns, with the same name?1838let all_simple_columns: Option<Vec<PlSmallStr>> = exprs1839.iter()1840.map(|e| match ctx.expr_arena.get(e.node()) {1841AExpr::Column(name) if name == e.output_name() => Some(name.clone()),1842_ => None,1843})1844.collect();18451846if let Some(columns) = all_simple_columns {1847let input_schema = ctx.phys_sm[input.node].output_schema.clone();1848if input_schema.len() == columns.len()1849&& input_schema.iter_names().zip(&columns).all(|(l, r)| l == r)1850{1851// Input node already has the correct schema, just pass through.1852return Ok(input);1853}18541855let output_schema = Arc::new(input_schema.try_project(&columns)?);1856let node_kind = PhysNodeKind::SimpleProjection { input, columns };1857let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));1858return Ok(PhysStream::first(node_key));1859}18601861// Actual lowering is needed.1862let node_exprs = exprs.iter().map(|e| e.node()).collect_vec();1863let (transformed_input, transformed_exprs) = lower_exprs_with_ctx(input, &node_exprs, ctx)?;1864let trans_expr_irs = exprs1865.iter()1866.zip(transformed_exprs)1867.map(|(e, te)| ExprIR::new(te, OutputName::Alias(e.output_name().clone())))1868.collect_vec();1869let output_schema = schema_for_select(transformed_input, &trans_expr_irs, ctx)?;1870let node_kind = PhysNodeKind::Select {1871input: transformed_input,1872selectors: trans_expr_irs,1873extend_original: false,1874};1875let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));1876Ok(PhysStream::first(node_key))1877}18781879/// Lowers an input node plus a set of expressions on that input node to an1880/// equivalent (input node, set of expressions) pair, ensuring that the new set1881/// of expressions can run on the streaming engine.1882///1883/// Ensures that if the input node is transformed it has unique column names.1884pub fn lower_exprs(1885input: PhysStream,1886exprs: &[ExprIR],1887expr_arena: &mut Arena<AExpr>,1888phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,1889expr_cache: &mut ExprCache,1890ctx: StreamingLowerIRContext,1891) -> PolarsResult<(PhysStream, Vec<ExprIR>)> {1892let mut ctx = LowerExprContext {1893expr_arena,1894phys_sm,1895cache: expr_cache,1896prepare_visualization: ctx.prepare_visualization,1897};1898let node_exprs = exprs.iter().map(|e| e.node()).collect_vec();1899let (transformed_input, transformed_exprs) =1900lower_exprs_with_ctx(input, &node_exprs, &mut ctx)?;1901let trans_expr_irs = exprs1902.iter()1903.zip(transformed_exprs)1904.map(|(e, te)| ExprIR::new(te, OutputName::Alias(e.output_name().clone())))1905.collect_vec();1906Ok((transformed_input, trans_expr_irs))1907}19081909/// Builds a new selection node given an input stream and the expressions to1910/// select for, if needed.1911pub fn build_select_stream(1912input: PhysStream,1913exprs: &[ExprIR],1914expr_arena: &mut Arena<AExpr>,1915phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,1916expr_cache: &mut ExprCache,1917ctx: StreamingLowerIRContext,1918) -> PolarsResult<PhysStream> {1919let mut ctx = LowerExprContext {1920expr_arena,1921phys_sm,1922cache: expr_cache,1923prepare_visualization: ctx.prepare_visualization,1924};1925build_select_stream_with_ctx(input, exprs, &mut ctx)1926}19271928/// Builds a hstack node given an input stream and the expressions to add.1929pub fn build_hstack_stream(1930input: PhysStream,1931exprs: &[ExprIR],1932expr_arena: &mut Arena<AExpr>,1933phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,1934expr_cache: &mut ExprCache,1935ctx: StreamingLowerIRContext,1936) -> PolarsResult<PhysStream> {1937let input_schema = &phys_sm[input.node].output_schema;1938if exprs1939.iter()1940.all(|e| is_elementwise_rec_cached(e.node(), expr_arena, expr_cache))1941{1942let mut output_schema = input_schema.as_ref().clone();1943for expr in exprs {1944output_schema.insert(1945expr.output_name().clone(),1946expr.dtype(input_schema, expr_arena)?.clone(),1947);1948}1949let output_schema = Arc::new(output_schema);19501951let selectors = exprs.to_vec();1952let kind = PhysNodeKind::Select {1953input,1954selectors,1955extend_original: true,1956};1957let node_key = phys_sm.insert(PhysNode {1958output_schema,1959kind,1960});19611962Ok(PhysStream::first(node_key))1963} else {1964// We already handled the all-streamable case above, so things get more complicated.1965// For simplicity we just do a normal select with all the original columns prepended.1966let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len());1967for name in input_schema.iter_names() {1968let col_name = name.clone();1969let col_expr = expr_arena.add(AExpr::Column(col_name.clone()));1970selectors.insert(1971name.clone(),1972ExprIR::new(col_expr, OutputName::ColumnLhs(col_name)),1973);1974}1975for expr in exprs {1976selectors.insert(expr.output_name().clone(), expr.clone());1977}1978let selectors = selectors.into_values().collect_vec();1979build_length_preserving_select_stream(1980input, &selectors, expr_arena, phys_sm, expr_cache, ctx,1981)1982}1983}19841985/// Builds a new selection node given an input stream and the expressions to1986/// select for, if needed. Preserves the length of the input, like in with_columns.1987pub fn build_length_preserving_select_stream(1988input: PhysStream,1989exprs: &[ExprIR],1990expr_arena: &mut Arena<AExpr>,1991phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,1992expr_cache: &mut ExprCache,1993ctx: StreamingLowerIRContext,1994) -> PolarsResult<PhysStream> {1995let mut ctx = LowerExprContext {1996expr_arena,1997phys_sm,1998cache: expr_cache,1999prepare_visualization: ctx.prepare_visualization,2000};2001let already_length_preserving = exprs2002.iter()2003.any(|expr| is_length_preserving_ctx(expr.node(), &mut ctx));2004let input_schema = &ctx.phys_sm[input.node].output_schema;2005if exprs.is_empty() || input_schema.is_empty() || already_length_preserving {2006return build_select_stream_with_ctx(input, exprs, &mut ctx);2007}20082009// Hacky work-around: append an input column with a temporary name, but2010// remove it from the final selector. This should ensure scalars gets zipped2011// back to the input to broadcast them.2012let tmp_name = unique_column_name();2013let first_col = ctx.expr_arena.add(AExpr::Column(2014input_schema.iter_names_cloned().next().unwrap(),2015));2016let mut tmp_exprs = Vec::with_capacity(exprs.len() + 1);2017tmp_exprs.extend(exprs.iter().cloned());2018tmp_exprs.push(ExprIR::new(first_col, OutputName::Alias(tmp_name.clone())));20192020let out_stream = build_select_stream_with_ctx(input, &tmp_exprs, &mut ctx)?;2021let PhysNodeKind::Select { selectors, .. } = &mut ctx.phys_sm[out_stream.node].kind else {2022unreachable!()2023};2024assert!(selectors.pop().unwrap().output_name() == &tmp_name);2025let out_schema = Arc::make_mut(&mut phys_sm[out_stream.node].output_schema);2026out_schema.shift_remove(tmp_name.as_ref()).unwrap();2027Ok(out_stream)2028}202920302031