Path: blob/main/crates/polars-plan/src/plans/optimizer/sortedness.rs
8424 views
use std::sync::Arc;12use polars_core::chunked_array::cast::CastOptions;3use polars_core::prelude::{FillNullStrategy, PlHashMap, PlHashSet};4use polars_core::schema::Schema;5use polars_core::series::IsSorted;6use polars_utils::arena::{Arena, Node};7use polars_utils::itertools::Itertools;8use polars_utils::pl_str::PlSmallStr;9use polars_utils::unique_id::UniqueId;1011#[cfg(all(feature = "strings", feature = "concat_str"))]12use crate::plans::IRStringFunction;13use crate::plans::{14AExpr, ExprIR, FunctionIR, HintIR, IR, IRFunctionExpr, Sorted, ToFieldContext,15constant_evaluate, into_column,16};1718#[derive(Debug, Clone)]19pub struct IRSorted(pub Arc<[Sorted]>);2021/// Are the keys together sorted in any way?22///23/// Returns the way in which the keys are sorted, if they are sorted.24pub fn are_keys_sorted_any(25ir_sorted: Option<&IRSorted>,26keys: &[ExprIR],27expr_arena: &Arena<AExpr>,28input_schema: &Schema,29) -> Option<Vec<AExprSorted>> {30let mut sortedness = Vec::with_capacity(keys.len());31for (idx, key) in keys.iter().enumerate() {32let s = aexpr_sortedness(33expr_arena.get(key.node()),34expr_arena,35input_schema,36Some(&ir_sorted?.0[idx..]),37)?;38sortedness.push(s);39}40Some(sortedness)41}4243pub fn is_sorted(root: Node, ir_arena: &Arena<IR>, expr_arena: &Arena<AExpr>) -> Option<IRSorted> {44let mut sortedness = PlHashMap::default();45let mut cache_proxy = PlHashMap::default();46let mut amort_passed_columns = PlHashSet::default();4748is_sorted_rec(49root,50ir_arena,51expr_arena,52&mut sortedness,53&mut cache_proxy,54&mut amort_passed_columns,55)56}5758#[recursive::recursive]59fn is_sorted_rec(60root: Node,61ir_arena: &Arena<IR>,62expr_arena: &Arena<AExpr>,63sortedness: &mut PlHashMap<Node, Option<IRSorted>>,64cache_proxy: &mut PlHashMap<UniqueId, Option<IRSorted>>,65amort_passed_columns: &mut PlHashSet<PlSmallStr>,66) -> Option<IRSorted> {67if let Some(s) = sortedness.get(&root) {68return s.clone();69}7071macro_rules! rec {72($node:expr) => {{73is_sorted_rec(74$node,75ir_arena,76expr_arena,77sortedness,78cache_proxy,79amort_passed_columns,80)81}};82}8384sortedness.insert(root, None);8586// @NOTE: Most of the below implementations are very very conservative.87let sorted = match ir_arena.get(root) {88#[cfg(feature = "python")]89IR::PythonScan { .. } => None,90IR::Slice {91input,92offset: _,93len: _,94} => rec!(*input),95IR::Filter {96input,97predicate: _,98} => rec!(*input),99IR::Scan { .. } => None,100IR::DataFrameScan { df, .. } => {101let sorted_cols = df102.columns()103.iter()104.filter_map(|c| match c.is_sorted_flag() {105IsSorted::Not => None,106IsSorted::Ascending => Some(Sorted {107column: c.name().clone(),108descending: Some(false),109nulls_last: Some(c.get(0).is_ok_and(|v| !v.is_null())),110}),111IsSorted::Descending => Some(Sorted {112column: c.name().clone(),113descending: Some(true),114nulls_last: Some(c.get(0).is_ok_and(|v| !v.is_null())),115}),116})117.collect_vec();118(!sorted_cols.is_empty()).then(|| IRSorted(sorted_cols.into()))119},120IR::SimpleProjection { input, columns } => {121let (input, columns) = (*input, columns.clone());122match rec!(input) {123None => None,124Some(v) => {125let first_unsorted_key = v.0.iter().position(|v| !columns.contains(&v.column));126match first_unsorted_key {127None => Some(v),128Some(0) => None,129Some(i) => Some(IRSorted(v.0.iter().take(i).cloned().collect())),130}131},132}133},134IR::Select { input, expr, .. } => {135let input = *input;136let input_sorted = rec!(input);137138if let Some(input_sorted) = &input_sorted {139// We can keep a sorted column if it was kept and not changed.140141amort_passed_columns.clear();142amort_passed_columns.extend(expr.iter().filter_map(|e| {143let column = into_column(e.node(), expr_arena)?;144(column == e.output_name()).then(|| column.clone())145}));146147let first_unkept_key = input_sorted148.0149.iter()150.position(|v| !amort_passed_columns.contains(&v.column));151match first_unkept_key {152None => Some(input_sorted.clone()),153Some(0) => {154let input_schema = ir_arena.get(input).schema(ir_arena);155first_expr_ir_sorted(156expr,157expr_arena,158input_schema.as_ref(),159Some(&input_sorted.0),160)161.map(|s| IRSorted([s].into()))162},163Some(i) => Some(IRSorted(input_sorted.0.iter().take(i).cloned().collect())),164}165} else {166let input_schema = ir_arena.get(input).schema(ir_arena);167first_expr_ir_sorted(expr, expr_arena, input_schema.as_ref(), None)168.map(|s| IRSorted([s].into()))169}170},171IR::HStack { input, exprs, .. } => {172let input = *input;173let input_sorted = rec!(input);174175if let Some(input_sorted) = &input_sorted {176// We can keep a sorted column if it was not overwritten.177178amort_passed_columns.clear();179amort_passed_columns.extend(exprs.iter().filter_map(|e| {180match into_column(e.node(), expr_arena) {181None => Some(e.output_name().clone()),182Some(c) if c == e.output_name() => None,183Some(_) => Some(e.output_name().clone()),184}185}));186187let first_overwritten_key = input_sorted188.0189.iter()190.position(|v| amort_passed_columns.contains(&v.column));191match first_overwritten_key {192None => Some(input_sorted.clone()),193Some(0) => {194let input_schema = ir_arena.get(input).schema(ir_arena);195first_expr_ir_sorted(196exprs,197expr_arena,198input_schema.as_ref(),199Some(&input_sorted.0),200)201.map(|s| IRSorted([s].into()))202},203Some(i) => Some(IRSorted(input_sorted.0.iter().take(i).cloned().collect())),204}205} else {206let input_schema = ir_arena.get(input).schema(ir_arena);207first_expr_ir_sorted(exprs, expr_arena, input_schema.as_ref(), None)208.map(|s| IRSorted([s].into()))209}210},211IR::Sort {212input: _,213by_column,214slice: _,215sort_options,216} => {217let mut s = by_column218.iter()219.map_while(|e| {220into_column(e.node(), expr_arena).map(|c| Sorted {221column: c.clone(),222descending: Some(false),223nulls_last: Some(false),224})225})226.collect::<Vec<_>>();227if sort_options.descending.len() != 1 {228s.iter_mut()229.zip(sort_options.descending.iter())230.for_each(|(s, &d)| s.descending = Some(d));231} else if sort_options.descending[0] {232s.iter_mut().for_each(|s| s.descending = Some(true));233}234if sort_options.nulls_last.len() != 1 {235s.iter_mut()236.zip(sort_options.nulls_last.iter())237.for_each(|(s, &d)| s.nulls_last = Some(d));238} else if sort_options.nulls_last[0] {239s.iter_mut().for_each(|s| s.nulls_last = Some(true));240}241242Some(IRSorted(s.into()))243},244IR::Cache { input, id } => {245let (input, id) = (*input, *id);246if let Some(s) = cache_proxy.get(&id) {247s.clone()248} else {249let s = rec!(input);250cache_proxy.insert(id, s.clone());251s252}253},254IR::GroupBy {255input,256keys,257options,258maintain_order: true,259..260} if !options.is_rolling() && !options.is_dynamic() => {261let input = *input;262let input_sorted = rec!(input)?;263264amort_passed_columns.clear();265amort_passed_columns.extend(keys.iter().filter_map(|e| {266let column = into_column(e.node(), expr_arena)?;267(column == e.output_name()).then(|| column.clone())268}));269270// We can keep a sorted key column if it was kept and not changed.271272let first_unkept_key = input_sorted273.0274.iter()275.position(|v| !amort_passed_columns.contains(&v.column));276match first_unkept_key {277None => Some(input_sorted.clone()),278Some(0) => {279let input_schema = ir_arena.get(input).schema(ir_arena);280first_expr_ir_sorted(keys, expr_arena, input_schema.as_ref(), None)281.map(|s| IRSorted([s].into()))282},283Some(i) => Some(IRSorted(input_sorted.0.iter().take(i).cloned().collect())),284}285},286#[cfg(feature = "dynamic_group_by")]287IR::GroupBy { options, .. } if options.is_rolling() => {288let Some(rolling_options) = &options.rolling else {289unreachable!()290};291Some(IRSorted(292[Sorted {293column: rolling_options.index_column.clone(),294descending: None,295nulls_last: None,296}]297.into(),298))299},300#[cfg(feature = "dynamic_group_by")]301IR::GroupBy { keys, options, .. } if options.is_dynamic() => {302let Some(dynamic_options) = &options.dynamic else {303unreachable!()304};305keys.is_empty().then(|| {306IRSorted(307[Sorted {308column: dynamic_options.index_column.clone(),309descending: None,310nulls_last: None,311}]312.into(),313)314})315},316317IR::GroupBy { .. } => None,318IR::Join { .. } => None,319IR::MapFunction { input, function } => match function {320FunctionIR::Hint(hint) => match hint {321HintIR::Sorted(v) => Some(IRSorted(v.clone())),322#[expect(unreachable_patterns)]323_ => rec!(*input),324},325_ => None,326},327IR::Union { .. } => None,328IR::HConcat { .. } => None,329IR::ExtContext { .. } => None,330IR::Sink { .. } => None,331IR::SinkMultiple { .. } => None,332#[cfg(feature = "merge_sorted")]333IR::MergeSorted { key, .. } => Some(IRSorted(334[Sorted {335column: key.clone(),336descending: None,337nulls_last: None,338}]339.into(),340)),341IR::Distinct { input, options } => {342if !options.maintain_order {343return None;344}345346let input = *input;347rec!(input)348},349IR::Invalid => unreachable!(),350};351352sortedness.insert(root, sorted.clone());353sorted354}355356#[derive(Debug, PartialEq)]357pub struct AExprSorted {358pub descending: Option<bool>,359pub nulls_last: Option<bool>,360}361362fn first_expr_ir_sorted(363exprs: &[ExprIR],364arena: &Arena<AExpr>,365schema: &Schema,366input_sorted: Option<&[Sorted]>,367) -> Option<Sorted> {368exprs.iter().find_map(|e| {369aexpr_sortedness(arena.get(e.node()), arena, schema, input_sorted).map(|s| Sorted {370column: e.output_name().clone(),371descending: s.descending,372nulls_last: s.nulls_last,373})374})375}376377#[recursive::recursive]378pub fn aexpr_sortedness(379aexpr: &AExpr,380arena: &Arena<AExpr>,381schema: &Schema,382input_sorted: Option<&[Sorted]>,383) -> Option<AExprSorted> {384match aexpr {385AExpr::Element => None,386AExpr::Explode { .. } => None,387AExpr::Column(col) => {388let fst = input_sorted?.first()?;389(fst.column == col).then_some(AExprSorted {390descending: fst.descending,391nulls_last: fst.nulls_last,392})393},394#[cfg(feature = "dtype-struct")]395AExpr::StructField(_) => None,396AExpr::Literal(lv) if lv.is_scalar() => Some(AExprSorted {397descending: Some(false),398nulls_last: Some(false),399}),400AExpr::Literal(_) => None,401402AExpr::Len => Some(AExprSorted {403descending: Some(false),404nulls_last: Some(false),405}),406AExpr::Cast {407expr,408dtype,409options: CastOptions::Strict,410} if dtype.is_integer() => {411let expr = arena.get(*expr);412let expr_sortedness = aexpr_sortedness(expr, arena, schema, input_sorted)?;413let input_dtype = expr.to_dtype(&ToFieldContext::new(arena, schema)).ok()?;414if !input_dtype.is_integer() {415return None;416}417Some(expr_sortedness)418},419AExpr::Cast { .. } => None, // @TODO: More casts are allowed420AExpr::Sort { expr: _, options } => Some(AExprSorted {421descending: Some(options.descending),422nulls_last: Some(options.nulls_last),423}),424AExpr::Function {425input,426function,427options: _,428} => function_expr_sortedness(function, input, arena, schema, input_sorted),429AExpr::Filter { input, by: _ }430| AExpr::Slice {431input,432offset: _,433length: _,434} => aexpr_sortedness(arena.get(*input), arena, schema, input_sorted),435436AExpr::BinaryExpr { .. }437| AExpr::Gather { .. }438| AExpr::SortBy { .. }439| AExpr::Agg(_)440| AExpr::Ternary { .. }441| AExpr::AnonymousAgg { .. }442| AExpr::AnonymousFunction { .. }443| AExpr::Eval { .. }444| AExpr::Over { .. } => None,445446#[cfg(feature = "dtype-struct")]447AExpr::StructEval { .. } => None,448449#[cfg(feature = "dynamic_group_by")]450AExpr::Rolling { .. } => None,451}452}453454pub fn function_expr_sortedness(455function: &IRFunctionExpr,456inputs: &[ExprIR],457arena: &Arena<AExpr>,458schema: &Schema,459input_sorted: Option<&[Sorted]>,460) -> Option<AExprSorted> {461macro_rules! rec_ae {462($node:expr) => {{ aexpr_sortedness(arena.get($node), arena, schema, input_sorted) }};463}464465match function {466#[cfg(feature = "rle")]467IRFunctionExpr::RLEID => Some(AExprSorted {468descending: Some(false),469nulls_last: Some(false),470}),471IRFunctionExpr::SetSortedFlag(is_sorted) => match is_sorted {472IsSorted::Ascending => Some(AExprSorted {473descending: Some(false),474nulls_last: None,475}),476IsSorted::Descending => Some(AExprSorted {477descending: Some(true),478nulls_last: None,479}),480IsSorted::Not => None,481},482483IRFunctionExpr::Unique(true)484| IRFunctionExpr::DropNulls485| IRFunctionExpr::DropNans486| IRFunctionExpr::FillNullWithStrategy(487FillNullStrategy::Forward(None) | FillNullStrategy::Backward(None),488) => {489let [e] = inputs else {490return None;491};492493rec_ae!(e.node())494},495#[cfg(feature = "mode")]496IRFunctionExpr::Mode {497maintain_order: true,498} => {499let [e] = inputs else {500return None;501};502503rec_ae!(e.node())504},505506#[cfg(feature = "range")]507IRFunctionExpr::Range(range) => {508use crate::plans::IRRangeFunction as R;509match range {510// `int_range(0, ..., step=1, dtype=UNSIGNED)`511R::IntRange { step: 1, dtype }512if dtype.is_unsigned_integer()513&& constant_evaluate(inputs[0].node(), arena, schema, 0)??514.extract_i64()515.is_ok_and(|v| v == 0) =>516{517Some(AExprSorted {518descending: Some(false),519nulls_last: Some(false),520})521},522523_ => None,524}525},526527IRFunctionExpr::Reverse => {528let [e] = inputs else {529return None;530};531532let mut sortedness = rec_ae!(e.node())?;533534if let Some(d) = &mut sortedness.descending {535*d = !*d;536}537if let Some(n) = &mut sortedness.nulls_last {538*n ^= !*n;539}540Some(sortedness)541},542543#[cfg(all(feature = "strings", feature = "concat_str"))]544IRFunctionExpr::StringExpr(IRStringFunction::ConcatHorizontal {545ignore_nulls: false,546delimiter: _,547}) => {548let [e] = inputs else {549return None;550};551552rec_ae!(e.node())553},554555_ => None,556}557}558559560