// Path: blob/main/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs
// 6940 views
use polars_core::prelude::*;1use polars_utils::idx_vec::UnitVec;2use polars_utils::slice_enum::Slice;3use recursive::recursive;45use crate::prelude::*;67mod inner {8use polars_utils::arena::Node;9use polars_utils::idx_vec::UnitVec;10use polars_utils::unitvec;1112pub struct SlicePushDown {13#[expect(unused)]14pub new_streaming: bool,15scratch: UnitVec<Node>,16pub(super) maintain_errors: bool,17}1819impl SlicePushDown {20pub fn new(maintain_errors: bool, new_streaming: bool) -> Self {21Self {22new_streaming,23scratch: unitvec![],24maintain_errors,25}26}2728/// Returns shared scratch space after clearing.29pub fn empty_nodes_scratch_mut(&mut self) -> &mut UnitVec<Node> {30self.scratch.clear();31&mut self.scratch32}33}34}3536pub(super) use inner::SlicePushDown;3738#[derive(Copy, Clone, Debug)]39struct State {40offset: i64,41len: IdxSize,42}4344impl State {45fn to_slice_enum(self) -> Slice {46let offset = self.offset;47let len: usize = usize::try_from(self.len).unwrap();4849(offset, len).into()50}51}5253/// Can push down slice when:54/// * all projections are elementwise55/// * at least 1 projection is based on a column (for height broadcast)56/// * projections not based on any column project as scalars57///58/// Returns (can_pushdown, can_pushdown_and_any_expr_has_column)59fn can_pushdown_slice_past_projections(60exprs: &[ExprIR],61arena: &Arena<AExpr>,62scratch: &mut UnitVec<Node>,63maintain_errors: bool,64) -> (bool, bool) {65scratch.clear();6667let mut can_pushdown_and_any_expr_has_column = false;6869for expr_ir in exprs.iter() {70scratch.push(expr_ir.node());7172// # "has_column"73// `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown,74// because `c` projects to a height independent from the input height. 
We check75// this by observing that `c` does not have any columns in its input nodes.76//77// TODO: Simply checking that a column node is present does not handle e.g.:78// `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`,79// `str.contains`, `str.contains_any` etc. - observe a column node is present80// but the output height is not dependent on it.81let mut has_column = false;82let mut literals_all_scalar = true;8384let mut pd_group = ExprPushdownGroup::Pushable;8586while let Some(node) = scratch.pop() {87let ae = arena.get(node);8889// We re-use the logic from predicate pushdown, as slices can be seen as a form of filtering.90// But we also do some bookkeeping here specific to slice pushdown.9192match ae {93AExpr::Column(_) => has_column = true,94AExpr::Literal(v) => literals_all_scalar &= v.is_scalar(),95_ => {},96}9798if pd_group99.update_with_expr(scratch, ae, arena)100.blocks_pushdown(maintain_errors)101{102return (false, false);103}104}105106// If there is no column then all literals must be scalar107if !(has_column || literals_all_scalar) {108return (false, false);109}110111can_pushdown_and_any_expr_has_column |= has_column112}113114(true, can_pushdown_and_any_expr_has_column)115}116117impl SlicePushDown {118// slice will be done at this node if we found any119// we also stop optimization120fn no_pushdown_finish_opt(121&self,122lp: IR,123state: Option<State>,124lp_arena: &mut Arena<IR>,125) -> PolarsResult<IR> {126match state {127Some(state) => {128let input = lp_arena.add(lp);129130let lp = IR::Slice {131input,132offset: state.offset,133len: state.len,134};135Ok(lp)136},137None => Ok(lp),138}139}140141/// slice will be done at this node, but we continue optimization142fn no_pushdown_restart_opt(143&mut self,144lp: IR,145state: Option<State>,146lp_arena: &mut Arena<IR>,147expr_arena: &mut Arena<AExpr>,148) -> PolarsResult<IR> {149let inputs = lp.get_inputs();150151let new_inputs = inputs152.into_iter()153.map(|node| {154let alp = 
lp_arena.take(node);155// No state, so we do not push down the slice here.156let state = None;157let alp = self.pushdown(alp, state, lp_arena, expr_arena)?;158lp_arena.replace(node, alp);159Ok(node)160})161.collect::<PolarsResult<UnitVec<_>>>()?;162let lp = lp.with_inputs(new_inputs);163164self.no_pushdown_finish_opt(lp, state, lp_arena)165}166167/// slice will be pushed down.168fn pushdown_and_continue(169&mut self,170lp: IR,171state: Option<State>,172lp_arena: &mut Arena<IR>,173expr_arena: &mut Arena<AExpr>,174) -> PolarsResult<IR> {175let inputs = lp.get_inputs();176177let new_inputs = inputs178.into_iter()179.map(|node| {180let alp = lp_arena.take(node);181let alp = self.pushdown(alp, state, lp_arena, expr_arena)?;182lp_arena.replace(node, alp);183Ok(node)184})185.collect::<PolarsResult<UnitVec<_>>>()?;186Ok(lp.with_inputs(new_inputs))187}188189#[recursive]190fn pushdown(191&mut self,192lp: IR,193state: Option<State>,194lp_arena: &mut Arena<IR>,195expr_arena: &mut Arena<AExpr>,196) -> PolarsResult<IR> {197use IR::*;198199match (lp, state) {200#[cfg(feature = "python")]201(PythonScan {202mut options,203},204// TODO! we currently skip slice pushdown if there is a predicate.205// we can modify the readers to only limit after predicates have been applied206Some(state)) if state.offset == 0 && matches!(options.predicate, PythonPredicate::None) => {207options.n_rows = Some(state.len as usize);208let lp = PythonScan {209options,210};211Ok(lp)212}213214(Scan {215sources,216file_info,217hive_parts,218output_schema,219mut unified_scan_args,220predicate,221scan_type,222}, Some(state)) if predicate.is_none() && match &*scan_type {223#[cfg(feature = "parquet")]224FileScanIR::Parquet { .. } => true,225226#[cfg(feature = "ipc")]227FileScanIR::Ipc { .. } => true,228229#[cfg(feature = "csv")]230FileScanIR::Csv { .. } => true,231232#[cfg(feature = "json")]233FileScanIR::NDJson { .. } => true,234235#[cfg(feature = "python")]236FileScanIR::PythonDataset { .. 
} => true,237238// TODO: This can be `true` after Anonymous scan dispatches to new-streaming.239FileScanIR::Anonymous { .. } => state.offset == 0,240} => {241unified_scan_args.pre_slice = Some(state.to_slice_enum());242243let lp = Scan {244sources,245file_info,246hive_parts,247output_schema,248scan_type,249unified_scan_args,250predicate,251};252253Ok(lp)254},255256(DataFrameScan {df, schema, output_schema, }, Some(state)) => {257let df = df.slice(state.offset, state.len as usize);258let lp = DataFrameScan {259df: Arc::new(df),260schema,261output_schema,262};263Ok(lp)264}265(Union {mut inputs, mut options }, Some(state)) => {266if state.offset == 0 {267for input in &mut inputs {268let input_lp = lp_arena.take(*input);269let input_lp = self.pushdown(input_lp, Some(state), lp_arena, expr_arena)?;270lp_arena.replace(*input, input_lp);271}272}273options.slice = Some((state.offset, state.len as usize));274let lp = Union {inputs, options};275Ok(lp)276},277(Join {278input_left,279input_right,280schema,281left_on,282right_on,283mut options284}, Some(state)) if !matches!(options.options, Some(JoinTypeOptionsIR::CrossAndFilter { .. 
})) => {285// first restart optimization in both inputs and get the updated LP286let lp_left = lp_arena.take(input_left);287let lp_left = self.pushdown(lp_left, None, lp_arena, expr_arena)?;288let input_left = lp_arena.add(lp_left);289290let lp_right = lp_arena.take(input_right);291let lp_right = self.pushdown(lp_right, None, lp_arena, expr_arena)?;292let input_right = lp_arena.add(lp_right);293294// then assign the slice state to the join operation295296let mut_options = Arc::make_mut(&mut options);297mut_options.args.slice = Some((state.offset, state.len as usize));298299Ok(Join {300input_left,301input_right,302schema,303left_on,304right_on,305options306})307}308(GroupBy { input, keys, aggs, schema, apply, maintain_order, mut options }, Some(state)) => {309// first restart optimization in inputs and get the updated LP310let input_lp = lp_arena.take(input);311let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?;312let input= lp_arena.add(input_lp);313314let mut_options= Arc::make_mut(&mut options);315mut_options.slice = Some((state.offset, state.len as usize));316317Ok(GroupBy {318input,319keys,320aggs,321schema,322apply,323maintain_order,324options325})326}327(Distinct {input, mut options}, Some(state)) => {328// first restart optimization in inputs and get the updated LP329let input_lp = lp_arena.take(input);330let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?;331let input= lp_arena.add(input_lp);332options.slice = Some((state.offset, state.len as usize));333Ok(Distinct {334input,335options,336})337}338(Sort {input, by_column, mut slice,339sort_options}, Some(state)) => {340// first restart optimization in inputs and get the updated LP341let input_lp = lp_arena.take(input);342let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?;343let input= lp_arena.add(input_lp);344345slice = Some((state.offset, state.len as usize));346Ok(Sort {347input,348by_column,349slice,350sort_options351})352}353(Slice 
{354input,355offset,356mut len357}, Some(outer_slice)) => {358let alp = lp_arena.take(input);359360// Both are positive, can combine into a single slice.361if outer_slice.offset >= 0 && offset >= 0 {362let state = State {363offset: offset.checked_add(outer_slice.offset).unwrap(),364len: if len as i64 > outer_slice.offset {365(len - outer_slice.offset as IdxSize).min(outer_slice.len)366} else {3670368},369};370return self.pushdown(alp, Some(state), lp_arena, expr_arena);371}372373// If offset is negative the length can never be greater than it.374if offset < 0 {375if len as i64 > -offset {376len = (-offset) as IdxSize;377}378}379380// Both are negative, can also combine (but not so simply).381if outer_slice.offset < 0 && offset < 0 {382let inner_start_rel_end = offset;383let inner_stop_rel_end = inner_start_rel_end + len as i64;384let naive_outer_start_rel_end = inner_stop_rel_end + outer_slice.offset;385let naive_outer_stop_rel_end = naive_outer_start_rel_end + outer_slice.len as i64;386let clamped_outer_start_rel_end = naive_outer_start_rel_end.max(inner_start_rel_end);387let clamped_outer_stop_rel_end = naive_outer_stop_rel_end.max(clamped_outer_start_rel_end);388389let state = State {390offset: clamped_outer_start_rel_end,391len: (clamped_outer_stop_rel_end - clamped_outer_start_rel_end) as IdxSize,392};393return self.pushdown(alp, Some(state), lp_arena, expr_arena);394}395396let inner_slice = Some(State { offset, len });397let lp = self.pushdown(alp, inner_slice, lp_arena, expr_arena)?;398let input = lp_arena.add(lp);399Ok(Slice {400input,401offset: outer_slice.offset,402len: outer_slice.len403})404}405(Slice {406input,407offset,408mut len409}, None) => {410let alp = lp_arena.take(input);411412// If offset is negative the length can never be greater than it.413if offset < 0 {414if len as i64 > -offset {415len = (-offset) as IdxSize;416}417}418419let state = Some(State {420offset,421len422});423self.pushdown(alp, state, lp_arena, expr_arena)424}425// [Do not 
pushdown] boundary426// here we do not pushdown.427// we reset the state and then start the optimization again428m @ (Filter { .. }, _)429// other blocking nodes430| m @ (DataFrameScan {..}, _)431| m @ (Sort {..}, _)432| m @ (MapFunction {function: FunctionIR::Explode {..}, ..}, _)433| m @ (Cache {..}, _)434| m @ (Distinct {..}, _)435| m @ (GroupBy{..},_)436// blocking in streaming437| m @ (Join{..},_)438=> {439let (lp, state) = m;440self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena)441},442#[cfg(feature = "pivot")]443m @ (MapFunction {function: FunctionIR::Unpivot {..}, ..}, _) => {444let (lp, state) = m;445self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena)446},447// [Pushdown]448(MapFunction {input, function}, _) if function.allow_predicate_pd() => {449let lp = MapFunction {input, function};450self.pushdown_and_continue(lp, state, lp_arena, expr_arena)451},452// [NO Pushdown]453m @ (MapFunction {..}, _) => {454let (lp, state) = m;455self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena)456}457// [Pushdown]458// these nodes will be pushed down.459// State is None, we can continue460m @ (Select {..}, None)461| m @ (HStack {..}, None)462| m @ (SimpleProjection {..}, _)463=> {464let (lp, state) = m;465self.pushdown_and_continue(lp, state, lp_arena, expr_arena)466}467// there is state, inspect the projection to determine how to deal with it468(Select {input, expr, schema, options}, Some(_)) => {469let maintain_errors = self.maintain_errors;470if can_pushdown_slice_past_projections(&expr, expr_arena, self.empty_nodes_scratch_mut(), maintain_errors).1 {471let lp = Select {input, expr, schema, options};472self.pushdown_and_continue(lp, state, lp_arena, expr_arena)473}474// don't push down slice, but restart optimization475else {476let lp = Select {input, expr, schema, options};477self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena)478}479}480(HStack {input, exprs, schema, options}, _) => {481let maintain_errors = 
self.maintain_errors;482let (can_pushdown, can_pushdown_and_any_expr_has_column) = can_pushdown_slice_past_projections(&exprs, expr_arena, self.empty_nodes_scratch_mut(), maintain_errors);483484if can_pushdown_and_any_expr_has_column || (485// If the schema length is greater then an input column is being projected, so486// the exprs in with_columns do not need to have an input column name.487schema.len() > exprs.len() && can_pushdown488)489{490let lp = HStack {input, exprs, schema, options};491self.pushdown_and_continue(lp, state, lp_arena, expr_arena)492}493// don't push down slice, but restart optimization494else {495let lp = HStack {input, exprs, schema, options};496self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena)497}498}499(HConcat {inputs, schema, options}, _) => {500// Slice can always be pushed down for horizontal concatenation501let lp = HConcat {inputs, schema, options};502self.pushdown_and_continue(lp, state, lp_arena, expr_arena)503}504(lp @ Sink { .. }, _) | (lp @ SinkMultiple { .. }, _) => {505// Slice can always be pushed down for sinks506self.pushdown_and_continue(lp, state, lp_arena, expr_arena)507}508(catch_all, state) => {509self.no_pushdown_finish_opt(catch_all, state, lp_arena)510}511}512}513514pub fn optimize(515&mut self,516logical_plan: IR,517lp_arena: &mut Arena<IR>,518expr_arena: &mut Arena<AExpr>,519) -> PolarsResult<IR> {520self.pushdown(logical_plan, None, lp_arena, expr_arena)521}522}523524525