// Path: crates/polars-plan/src/plans/optimizer/count_star.rs
use polars_io::cloud::CloudOptions;1use polars_utils::mmap::MemSlice;2use polars_utils::plpath::PlPath;34use super::*;56pub(super) struct CountStar;78impl CountStar {9pub(super) fn new() -> Self {10Self11}12}1314impl CountStar {15// Replace select count(*) from datasource with specialized map function.16pub(super) fn optimize_plan(17&mut self,18lp_arena: &mut Arena<IR>,19expr_arena: &mut Arena<AExpr>,20mut node: Node,21) -> PolarsResult<Option<IR>> {22// New-streaming always puts a sink on top.23if let IR::Sink { input, .. } = lp_arena.get(node) {24node = *input;25}2627// Note: This will be a useful flag later for testing parallel CountLines on CSV.28let use_fast_file_count = match std::env::var("POLARS_FAST_FILE_COUNT_DISPATCH").as_deref()29{30Ok("1") => Some(true),31Ok("0") => Some(false),32Ok(v) => panic!("POLARS_FAST_FILE_COUNT_DISPATCH must be one of ('0', '1'), got: {v}"),33Err(_) => None,34};3536Ok(visit_logical_plan_for_scan_paths(37node,38lp_arena,39expr_arena,40false,41use_fast_file_count,42)43.map(|count_star_expr| {44// MapFunction needs a leaf node, hence we create a dummy placeholder node45let placeholder = IR::DataFrameScan {46df: Arc::new(Default::default()),47schema: Arc::new(Default::default()),48output_schema: None,49};50let placeholder_node = lp_arena.add(placeholder);5152let alp = IR::MapFunction {53input: placeholder_node,54function: FunctionIR::FastCount {55sources: count_star_expr.sources,56scan_type: count_star_expr.scan_type,57cloud_options: count_star_expr.cloud_options,58alias: count_star_expr.alias,59},60};6162lp_arena.replace(count_star_expr.node, alp.clone());63alp64}))65}66}6768struct CountStarExpr {69// Top node of the projection to replace70node: Node,71// Paths to the input files72sources: ScanSources,73cloud_options: Option<CloudOptions>,74// File Type75scan_type: Box<FileScanIR>,76// Column Alias77alias: Option<PlSmallStr>,78}7980// Visit the logical plan and return CountStarExpr with the expr information gathered81// Return 
None if query is not a simple COUNT(*) FROM SOURCE82fn visit_logical_plan_for_scan_paths(83node: Node,84lp_arena: &Arena<IR>,85expr_arena: &Arena<AExpr>,86inside_union: bool, // Inside union's we do not check for COUNT(*) expression87use_fast_file_count: Option<bool>, // Overrides if Some88) -> Option<CountStarExpr> {89match lp_arena.get(node) {90IR::Union { inputs, .. } => {91enum MutableSources {92Addresses(Vec<PlPath>),93Buffers(Vec<MemSlice>),94}9596let mut scan_type: Option<Box<FileScanIR>> = None;97let mut cloud_options = None;98let mut sources = None;99100for input in inputs {101match visit_logical_plan_for_scan_paths(102*input,103lp_arena,104expr_arena,105true,106use_fast_file_count,107) {108Some(expr) => {109match (expr.sources, &mut sources) {110(111ScanSources::Paths(addrs),112Some(MutableSources::Addresses(mutable_addrs)),113) => mutable_addrs.extend_from_slice(&addrs[..]),114(ScanSources::Paths(addrs), None) => {115sources = Some(MutableSources::Addresses(addrs.to_vec()))116},117(118ScanSources::Buffers(buffers),119Some(MutableSources::Buffers(mutable_buffers)),120) => mutable_buffers.extend_from_slice(&buffers[..]),121(ScanSources::Buffers(buffers), None) => {122sources = Some(MutableSources::Buffers(buffers.to_vec()))123},124_ => return None,125}126127// Take the first Some(_) cloud option128// TODO: Should check the cloud types are the same.129cloud_options = cloud_options.or(expr.cloud_options);130131match &scan_type {132None => scan_type = Some(expr.scan_type),133Some(scan_type) => {134// All scans must be of the same type (e.g. 
csv / parquet)135if std::mem::discriminant(&**scan_type)136!= std::mem::discriminant(&*expr.scan_type)137{138return None;139}140},141};142},143None => return None,144}145}146Some(CountStarExpr {147sources: match sources {148Some(MutableSources::Addresses(addrs)) => ScanSources::Paths(addrs.into()),149Some(MutableSources::Buffers(buffers)) => ScanSources::Buffers(buffers.into()),150None => ScanSources::default(),151},152scan_type: scan_type.unwrap(),153cloud_options,154node,155alias: None,156})157},158IR::Scan {159scan_type,160sources,161unified_scan_args,162..163} => {164// New-streaming is generally on par for all except CSV (see https://github.com/pola-rs/polars/pull/22363).165// In the future we can potentially remove the dedicated count codepaths.166167let use_fast_file_count = use_fast_file_count.unwrap_or(match scan_type.as_ref() {168#[cfg(feature = "csv")]169FileScanIR::Csv { .. } => true,170_ => false,171});172173if use_fast_file_count {174Some(CountStarExpr {175sources: sources.clone(),176scan_type: scan_type.clone(),177cloud_options: unified_scan_args.cloud_options.clone(),178node,179alias: None,180})181} else {182None183}184},185// A union can insert a simple projection to ensure all projections align.186// We can ignore that if we are inside a count star.187IR::SimpleProjection { input, .. } if inside_union => visit_logical_plan_for_scan_paths(188*input,189lp_arena,190expr_arena,191false,192use_fast_file_count,193),194IR::Select { input, expr, .. 
} => {195if expr.len() == 1 {196let (valid, alias) = is_valid_count_expr(&expr[0], expr_arena);197if valid || inside_union {198return visit_logical_plan_for_scan_paths(199*input,200lp_arena,201expr_arena,202false,203use_fast_file_count,204)205.map(|mut expr| {206expr.alias = alias;207expr.node = node;208expr209});210}211}212None213},214_ => None,215}216}217218fn is_valid_count_expr(e: &ExprIR, expr_arena: &Arena<AExpr>) -> (bool, Option<PlSmallStr>) {219match expr_arena.get(e.node()) {220AExpr::Len => (true, e.get_alias().cloned()),221_ => (false, None),222}223}224225226