Path: blob/main/crates/polars-lazy/src/tests/optimization_checks.rs
6939 views
use super::*;12#[cfg(feature = "parquet")]3pub(crate) fn row_index_at_scan(q: LazyFrame) -> bool {4let (mut expr_arena, mut lp_arena) = get_arenas();5let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();67lp_arena.iter(lp).any(|(_, lp)| {8if let IR::Scan {9unified_scan_args, ..10} = lp11{12unified_scan_args.row_index.is_some()13} else {14false15}16})17}1819pub(crate) fn predicate_at_scan(q: LazyFrame) -> bool {20let (mut expr_arena, mut lp_arena) = get_arenas();21let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();2223lp_arena.iter(lp).any(|(_, lp)| match lp {24IR::Filter { input, .. } => {25matches!(lp_arena.get(*input), IR::DataFrameScan { .. })26},27IR::Scan {28predicate: Some(_), ..29} => true,30_ => false,31})32}3334pub(crate) fn predicate_at_all_scans(q: LazyFrame) -> bool {35let (mut expr_arena, mut lp_arena) = get_arenas();36let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();3738lp_arena.iter(lp).all(|(_, lp)| match lp {39IR::Filter { input, .. } => {40matches!(lp_arena.get(*input), IR::DataFrameScan { .. })41},42IR::Scan {43predicate: Some(_), ..44} => true,45_ => false,46})47}4849#[cfg(any(feature = "parquet", feature = "csv"))]50fn slice_at_scan(q: LazyFrame) -> bool {51let (mut expr_arena, mut lp_arena) = get_arenas();52let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();53lp_arena.iter(lp).any(|(_, lp)| {54use IR::*;55match lp {56Scan {57unified_scan_args, ..58} => unified_scan_args.pre_slice.is_some(),59_ => false,60}61})62}6364#[test]65fn test_pred_pd_1() -> PolarsResult<()> {66let df = fruits_cars();6768let q = df69.clone()70.lazy()71.select([col("A"), col("B")])72.filter(col("A").gt(lit(1)));7374assert!(predicate_at_scan(q));7576// Check if we understand that we can unwrap the alias.77let q = df78.clone()79.lazy()80.select([col("A").alias("C"), col("B")])81.filter(col("C").gt(lit(1)));8283assert!(predicate_at_scan(q));8485// Check if we pass hstack.86let q = df87.lazy()88.with_columns([col("A").alias("C"), col("B")])89.filter(col("B").gt(lit(1)));9091assert!(predicate_at_scan(q));9293Ok(())94}9596#[test]97fn test_no_left_join_pass() -> PolarsResult<()> {98let df1 = df![99"foo" => ["abc", "def", "ghi"],100"idx1" => [0, 0, 1],101]?;102let df2 = df![103"bar" => [5, 6],104"idx2" => [0, 1],105]?;106107let out = df1108.lazy()109.join(110df2.lazy(),111[col("idx1")],112[col("idx2")],113JoinType::Left.into(),114)115.filter(col("bar").eq(lit(5i32)))116.collect()?;117118let expected = df![119"foo" => ["abc", "def"],120"idx1" => [0, 0],121"bar" => [5, 5],122]?;123124assert!(out.equals(&expected));125Ok(())126}127128#[test]129#[cfg(feature = "parquet")]130pub fn test_simple_slice() -> PolarsResult<()> {131let _guard = SINGLE_LOCK.lock().unwrap();132let q = scan_foods_parquet(false).limit(3);133134assert!(slice_at_scan(q.clone()));135let out = q.collect()?;136assert_eq!(out.height(), 3);137138let q = scan_foods_parquet(false)139.select([col("category"), col("calories").alias("bar")])140.limit(3);141assert!(slice_at_scan(q.clone()));142let out = q.collect()?;143assert_eq!(out.height(), 3);144145Ok(())146}147148#[test]149#[cfg(feature = "parquet")]150#[cfg(feature = "cse")]151pub fn test_slice_pushdown_join() -> PolarsResult<()> {152let _guard = SINGLE_LOCK.lock().unwrap();153let q1 = scan_foods_parquet(false).limit(3);154let q2 = scan_foods_parquet(false);155156let q = q1157.join(158q2,159[col("category")],160[col("category")],161JoinType::Left.into(),162)163.slice(1, 3)164// this inserts a cache and blocks slice pushdown165.with_comm_subplan_elim(false);166// test if optimization continued beyond the join node167assert!(slice_at_scan(q.clone()));168169let (mut expr_arena, mut lp_arena) = get_arenas();170let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();171assert!(lp_arena.iter(lp).all(|(_, lp)| {172use IR::*;173match lp {174Join { options, .. } => options.args.slice == Some((1, 3)),175Slice { .. } => false,176_ => true,177}178}));179let out = q.collect()?;180assert_eq!(out.shape(), (3, 7));181182Ok(())183}184185#[test]186#[cfg(feature = "parquet")]187pub fn test_slice_pushdown_group_by() -> PolarsResult<()> {188let _guard = SINGLE_LOCK.lock().unwrap();189let q = scan_foods_parquet(false).limit(100);190191let q = q192.group_by([col("category")])193.agg([col("calories").sum()])194.slice(1, 3);195196// test if optimization continued beyond the group_by node197assert!(slice_at_scan(q.clone()));198199let (mut expr_arena, mut lp_arena) = get_arenas();200let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();201assert!(lp_arena.iter(lp).all(|(_, lp)| {202use IR::*;203match lp {204GroupBy { options, .. } => options.slice == Some((1, 3)),205Slice { .. } => false,206_ => true,207}208}));209let out = q.collect()?;210assert_eq!(out.shape(), (3, 2));211212Ok(())213}214215#[test]216#[cfg(feature = "parquet")]217pub fn test_slice_pushdown_sort() -> PolarsResult<()> {218let _guard = SINGLE_LOCK.lock().unwrap();219let q = scan_foods_parquet(false).limit(100);220221let q = q222.sort(["category"], SortMultipleOptions::default())223.slice(1, 3);224225// test if optimization continued beyond the sort node226assert!(slice_at_scan(q.clone()));227228let (mut expr_arena, mut lp_arena) = get_arenas();229let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();230assert!(lp_arena.iter(lp).all(|(_, lp)| {231use IR::*;232match lp {233Sort { slice, .. } => *slice == Some((1, 3)),234Slice { .. } => false,235_ => true,236}237}));238let out = q.collect()?;239assert_eq!(out.shape(), (3, 4));240241Ok(())242}243244#[test]245#[cfg(feature = "dtype-i16")]246pub fn test_predicate_block_cast() -> PolarsResult<()> {247let df = df![248"value" => [10, 20, 30, 40]249]?;250251let lf1 = df252.clone()253.lazy()254.with_column(col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32))255.filter(col("value").lt(lit(2.5f32)));256257let lf2 = df258.lazy()259.select([col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32)])260.filter(col("value").lt(lit(2.5f32)));261262for lf in [lf1, lf2] {263assert!(!predicate_at_scan(lf.clone()));264265let out = lf.collect()?;266let s = out.column("value").unwrap();267assert_eq!(268s,269&Column::new(PlSmallStr::from_static("value"), [1.0f32, 2.0])270);271}272273Ok(())274}275276#[test]277fn test_lazy_filter_and_rename() {278let df = load_df();279let lf = df280.clone()281.lazy()282.rename(["a"], ["x"], true)283.filter(col("x").map(284|s: Column| Ok(s.as_materialized_series().gt(3)?.into_column()),285|_, f| Ok(Field::new(f.name().clone(), DataType::Boolean)),286))287.select([col("x")]);288289let correct = df! {290"x" => &[4, 5]291}292.unwrap();293assert!(lf.collect().unwrap().equals(&correct));294295// now we check if the column is rename or added when we don't select296let lf = df.lazy().rename(["a"], ["x"], true).filter(col("x").map(297|s: Column| Ok(s.as_materialized_series().gt(3)?.into_column()),298|_, f| Ok(Field::new(f.name().clone(), DataType::Boolean)),299));300// the rename function should not interfere with the predicate pushdown301assert!(predicate_at_scan(lf.clone()));302303assert_eq!(lf.collect().unwrap().get_column_names(), &["x", "b", "c"]);304}305306#[test]307fn test_with_row_index_opts() -> PolarsResult<()> {308let df = df![309"a" => [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]310]?;311312let out = df313.clone()314.lazy()315.with_row_index("index", None)316.tail(5)317.collect()?;318let expected = df![319"index" => [5 as IdxSize, 6, 7, 8, 9],320"a" => [5, 6, 7, 8, 9],321]?;322323assert!(out.equals(&expected));324let out = df325.clone()326.lazy()327.with_row_index("index", None)328.slice(1, 2)329.collect()?;330assert_eq!(331out.column("index")?332.idx()?333.into_no_null_iter()334.collect::<Vec<_>>(),335&[1, 2]336);337338let out = df339.clone()340.lazy()341.with_row_index("index", None)342.filter(col("a").eq(lit(3i32)))343.collect()?;344assert_eq!(345out.column("index")?346.idx()?347.into_no_null_iter()348.collect::<Vec<_>>(),349&[3]350);351352let out = df353.clone()354.lazy()355.slice(1, 2)356.with_row_index("index", None)357.collect()?;358assert_eq!(359out.column("index")?360.idx()?361.into_no_null_iter()362.collect::<Vec<_>>(),363&[0, 1]364);365366let out = df367.lazy()368.filter(col("a").eq(lit(3i32)))369.with_row_index("index", None)370.collect()?;371assert_eq!(372out.column("index")?373.idx()?374.into_no_null_iter()375.collect::<Vec<_>>(),376&[0]377);378379Ok(())380}381382#[cfg(all(feature = "concat_str", feature = "strings"))]383#[test]384fn test_string_addition_to_concat_str() -> PolarsResult<()> {385let df = df![386"a"=> ["a"],387"b"=> ["b"],388]?;389390let q = df391.lazy()392.select([lit("foo") + col("a") + col("b") + lit("bar")]);393394let (mut expr_arena, mut lp_arena) = get_arenas();395let root = q.clone().optimize(&mut lp_arena, &mut expr_arena)?;396let lp = lp_arena.get(root);397let e = lp.exprs().next().unwrap();398if let AExpr::Function { input, .. } = expr_arena.get(e.node()) {399// the concat_str has the 4 expressions as input400assert_eq!(input.len(), 4);401} else {402panic!()403}404405let out = q.collect()?;406let s = out.column("literal")?;407assert_eq!(s.get(0)?, AnyValue::String("fooabbar"));408409Ok(())410}411#[test]412fn test_with_column_prune() -> PolarsResult<()> {413// don't414let df = df![415"c0" => [0],416"c1" => [0],417"c2" => [0],418]?;419let (mut expr_arena, mut lp_arena) = get_arenas();420421// only a single expression pruned and only one column selection422let q = df423.clone()424.lazy()425.with_columns([col("c0"), col("c1").alias("c4")])426.select([col("c1"), col("c4")]);427let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();428lp_arena.iter(lp).for_each(|(_, lp)| {429use IR::*;430match lp {431DataFrameScan { output_schema, .. } => {432let projection = output_schema.as_ref().unwrap();433assert_eq!(projection.len(), 1);434let name = projection.get_at_index(0).unwrap().0;435assert_eq!(name, "c1");436},437HStack { exprs, .. } => {438assert_eq!(exprs.len(), 1);439},440_ => {},441};442});443444// whole `with_columns` pruned445let mut q = df.lazy().with_column(col("c0")).select([col("c1")]);446447let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();448449// check if with_column is pruned450assert!(lp_arena.iter(lp).all(|(_, lp)| {451use IR::*;452453matches!(lp, SimpleProjection { .. } | DataFrameScan { .. })454}));455assert_eq!(456q.collect_schema().unwrap().as_ref(),457&Schema::from_iter([Field::new(PlSmallStr::from_static("c1"), DataType::Int32)])458);459Ok(())460}461462#[test]463#[cfg(feature = "csv")]464fn test_slice_at_scan_group_by() -> PolarsResult<()> {465let ldf = scan_foods_csv();466467// this tests if slice pushdown restarts aggregation nodes (it did not)468let q = ldf469.slice(0, 5)470.filter(col("calories").lt(lit(10)))471.group_by([col("calories")])472.agg([col("fats_g").first()])473.select([col("fats_g")]);474475assert!(slice_at_scan(q));476Ok(())477}478479#[test]480fn test_flatten_unions() -> PolarsResult<()> {481let (mut expr_arena, mut lp_arena) = get_arenas();482483let lf = df! {484"a" => [1,2,3,4,5],485}486.unwrap()487.lazy();488489let args = UnionArgs {490rechunk: false,491parallel: true,492..Default::default()493};494let lf2 = concat(&[lf.clone(), lf.clone()], args).unwrap();495let lf3 = concat(&[lf.clone(), lf.clone(), lf], args).unwrap();496let lf4 = concat(&[lf2, lf3], args).unwrap();497let root = lf4.optimize(&mut lp_arena, &mut expr_arena).unwrap();498let lp = lp_arena.get(root);499match lp {500IR::Union { inputs, .. } => {501// we make sure that the nested unions are flattened into a single union502assert_eq!(inputs.len(), 5);503},504_ => panic!(),505}506Ok(())507}508509fn num_occurrences(s: &str, needle: &str) -> usize {510let mut i = 0;511let mut num = 0;512513while let Some(n) = s[i..].find(needle) {514i += n + 1;515num += 1;516}517518num519}520521#[test]522fn test_cluster_with_columns() -> Result<(), Box<dyn std::error::Error>> {523use polars_core::prelude::*;524525let df = df!("foo" => &[0.5, 1.7, 3.2],526"bar" => &[4.1, 1.5, 9.2])?;527528let df = df529.lazy()530.without_optimizations()531.with_cluster_with_columns(true)532.with_columns([col("foo") * lit(2.0)])533.with_columns([col("bar") / lit(1.5)]);534535let unoptimized = df.clone().to_alp().unwrap();536let optimized = df.to_alp_optimized().unwrap();537538let unoptimized = unoptimized.describe();539let optimized = optimized.describe();540541println!("\n---\n");542543println!("Unoptimized:\n{unoptimized}",);544println!("\n---\n");545println!("Optimized:\n{optimized}");546547assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 2);548assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 1);549550Ok(())551}552553#[test]554fn test_cluster_with_columns_dependency() -> Result<(), Box<dyn std::error::Error>> {555use polars_core::prelude::*;556557let df = df!("foo" => &[0.5, 1.7, 3.2],558"bar" => &[4.1, 1.5, 9.2])?;559560let df = df561.lazy()562.without_optimizations()563.with_cluster_with_columns(true)564.with_columns([col("foo").alias("buzz")])565.with_columns([col("buzz")]);566567let unoptimized = df.clone().to_alp().unwrap();568let optimized = df.to_alp_optimized().unwrap();569570let unoptimized = unoptimized.describe();571let optimized = optimized.describe();572573println!("\n---\n");574575println!("Unoptimized:\n{unoptimized}",);576println!("\n---\n");577println!("Optimized:\n{optimized}");578579assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 2);580assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 2);581582Ok(())583}584585#[test]586fn test_cluster_with_columns_partial() -> Result<(), Box<dyn std::error::Error>> {587use polars_core::prelude::*;588589let df = df!("foo" => &[0.5, 1.7, 3.2],590"bar" => &[4.1, 1.5, 9.2])?;591592let df = df593.lazy()594.without_optimizations()595.with_cluster_with_columns(true)596.with_columns([col("foo").alias("buzz")])597.with_columns([col("buzz"), col("foo") * lit(2.0)]);598599let unoptimized = df.clone().to_alp().unwrap();600let optimized = df.to_alp_optimized().unwrap();601602let unoptimized = unoptimized.describe();603let optimized = optimized.describe();604605println!("\n---\n");606607println!("Unoptimized:\n{unoptimized}",);608println!("\n---\n");609println!("Optimized:\n{optimized}");610611assert!(unoptimized.contains(r#"[col("buzz"), [(col("foo")) * (2.0)]]"#));612assert!(unoptimized.contains(r#"[col("foo").alias("buzz")]"#));613assert!(optimized.contains(r#"[col("buzz")]"#));614assert!(optimized.contains(r#"[col("foo").alias("buzz"), [(col("foo")) * (2.0)]]"#));615616Ok(())617}618619#[test]620fn test_cluster_with_columns_chain() -> Result<(), Box<dyn std::error::Error>> {621use polars_core::prelude::*;622623let df = df!("foo" => &[0.5, 1.7, 3.2],624"bar" => &[4.1, 1.5, 9.2])?;625626let df = df627.lazy()628.without_optimizations()629.with_cluster_with_columns(true)630.with_columns([col("foo").alias("foo1")])631.with_columns([col("foo").alias("foo2")])632.with_columns([col("foo").alias("foo3")])633.with_columns([col("foo").alias("foo4")]);634635let unoptimized = df.clone().to_alp().unwrap();636let optimized = df.to_alp_optimized().unwrap();637638let unoptimized = unoptimized.describe();639let optimized = optimized.describe();640641println!("\n---\n");642643println!("Unoptimized:\n{unoptimized}",);644println!("\n---\n");645println!("Optimized:\n{optimized}");646647assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 4);648assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 1);649650Ok(())651}652653654