Path: blob/main/crates/polars-lazy/src/tests/aggregations.rs
6939 views
use polars_ops::prelude::ListNameSpaceImpl;1use polars_utils::unitvec;23use super::*;45#[test]6#[cfg(feature = "dtype-datetime")]7fn test_agg_list_type() -> PolarsResult<()> {8let s = Series::new("foo".into(), &[1, 2, 3]);9let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?;1011let l = unsafe { s.agg_list(&GroupsType::Idx(vec![(0, unitvec![0, 1, 2])].into())) };1213let result = match l.dtype() {14DataType::List(inner) => {15matches!(&**inner, DataType::Datetime(TimeUnit::Nanoseconds, None))16},17_ => false,18};19assert!(result);2021Ok(())22}2324#[test]25fn test_agg_exprs() -> PolarsResult<()> {26let df = fruits_cars();2728// a binary expression followed by a function and an aggregation. See if it runs29let out = df30.lazy()31.group_by_stable([col("cars")])32.agg([(lit(1) - col("A"))33.map(|s| Ok(&s * 2), |_, f| Ok(f.clone()))34.alias("foo")])35.collect()?;36let ca = out.column("foo")?.list()?;37let out = ca.lst_lengths();3839assert_eq!(Vec::from(&out), &[Some(4), Some(1)]);40Ok(())41}4243#[test]44fn test_agg_unique_first() -> PolarsResult<()> {45let df = df![46"g"=> [1, 1, 2, 2, 3, 4, 1],47"v"=> [1, 2, 2, 2, 3, 4, 1],48]?;4950let out = df51.lazy()52.group_by_stable([col("g")])53.agg([54col("v").unique().first().alias("v_first"),55col("v")56.unique()57.sort(Default::default())58.first()59.alias("true_first"),60col("v").unique().implode(),61])62.collect()?;6364let a = out.column("v_first").unwrap();65let a = a.as_materialized_series().sum::<i32>().unwrap();66// can be both because unique does not guarantee order67assert!(a == 10 || a == 11);6869let a = out.column("true_first").unwrap();70let a = a.as_materialized_series().sum::<i32>().unwrap();71// can be both because unique does not guarantee order72assert_eq!(a, 10);7374Ok(())75}7677#[test]78#[cfg(feature = "cum_agg")]79fn test_cum_sum_agg_as_key() -> PolarsResult<()> {80let df = df![81"depth" => &[0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9],82"soil" => &["peat", "peat", "peat", "silt", "silt", "silt", "sand", "sand", "peat", "peat"]83]?;84// this checks if the grouper can work with the complex query as a key8586let out = df87.lazy()88.group_by([col("soil")89.neq(col("soil").shift_and_fill(lit(1), col("soil").first()))90.cum_sum(false)91.alias("key")])92.agg([col("depth").max().name().keep()])93.sort(["depth"], Default::default())94.collect()?;9596assert_eq!(97Vec::from(out.column("key")?.u32()?),98&[Some(0), Some(1), Some(2), Some(3)]99);100assert_eq!(101Vec::from(out.column("depth")?.i32()?),102&[Some(2), Some(5), Some(7), Some(9)]103);104105Ok(())106}107108#[test]109#[cfg(feature = "moment")]110fn test_auto_skew_kurtosis_agg() -> PolarsResult<()> {111let df = fruits_cars();112113let out = df114.lazy()115.group_by([col("fruits")])116.agg([117col("B").skew(false).alias("bskew"),118col("B").kurtosis(false, false).alias("bkurt"),119])120.collect()?;121122assert!(matches!(out.column("bskew")?.dtype(), DataType::Float64));123assert!(matches!(out.column("bkurt")?.dtype(), DataType::Float64));124125Ok(())126}127128#[test]129fn test_auto_list_agg() -> PolarsResult<()> {130let df = fruits_cars();131132// test if alias executor adds a list after shift and fill133let out = df134.clone()135.lazy()136.group_by([col("fruits")])137.agg([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")])138.collect()?;139140assert!(matches!(out.column("foo")?.dtype(), DataType::List(_)));141142// test if it runs and group_by executor thus implements a list after shift_and_fill143let _out = df144.clone()145.lazy()146.group_by([col("fruits")])147.agg([col("B").shift_and_fill(lit(-1), lit(-1))])148.collect()?;149150// test if window expr executor adds list151let _out = df152.clone()153.lazy()154.select([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")])155.collect()?;156157let _out = df158.lazy()159.select([col("B").shift_and_fill(lit(-1), lit(-1))])160.collect()?;161Ok(())162}163#[test]164#[cfg(feature = "rolling_window")]165fn test_power_in_agg_list1() -> PolarsResult<()> {166let df = fruits_cars();167168// this test if the group tuples are correctly updated after169// a flat apply on a final aggregation170let out = df171.lazy()172.group_by([col("fruits")])173.agg([174col("A")175.rolling_min(RollingOptionsFixedWindow {176window_size: 1,177..Default::default()178})179.alias("input"),180col("A")181.rolling_min(RollingOptionsFixedWindow {182window_size: 1,183..Default::default()184})185.pow(2.0)186.alias("foo"),187])188.sort(189["fruits"],190SortMultipleOptions::default().with_order_descending(true),191)192.collect()?;193194let agg = out.column("foo")?.list()?;195let first = agg.get_as_series(0).unwrap();196let vals = first.f64()?;197assert_eq!(Vec::from(vals), &[Some(1.0), Some(4.0), Some(25.0)]);198199Ok(())200}201202#[test]203#[cfg(feature = "rolling_window")]204fn test_power_in_agg_list2() -> PolarsResult<()> {205let df = fruits_cars();206207// this test if the group tuples are correctly updated after208// a flat apply on evaluate_on_groups209let out = df210.lazy()211.group_by([col("fruits")])212.agg([col("A")213.rolling_min(RollingOptionsFixedWindow {214window_size: 2,215min_periods: 2,216..Default::default()217})218.pow(2.0)219.sum()220.alias("foo")])221.sort(222["fruits"],223SortMultipleOptions::default().with_order_descending(true),224)225.collect()?;226227let agg = out.column("foo")?.f64()?;228assert_eq!(Vec::from(agg), &[Some(5.0), Some(9.0)]);229230Ok(())231}232#[test]233fn test_binary_agg_context_0() -> PolarsResult<()> {234let df = df![235"groups" => [1, 1, 2, 2, 3, 3],236"vals" => [1, 2, 3, 4, 5, 6]237]238.unwrap();239240let out = df241.lazy()242.group_by_stable([col("groups")])243.agg([when(col("vals").first().neq(lit(1)))244.then(repeat(lit("a"), len()))245.otherwise(repeat(lit("b"), len()))246.alias("foo")])247.collect()248.unwrap();249250let out = out.column("foo")?;251let out = out.explode(false)?;252let out = out.str()?;253assert_eq!(254Vec::from(out),255&[256Some("b"),257Some("b"),258Some("a"),259Some("a"),260Some("a"),261Some("a")262]263);264Ok(())265}266267// just like binary expression, this must be changed. This can work268#[test]269fn test_binary_agg_context_1() -> PolarsResult<()> {270let df = df![271"groups" => [1, 1, 2, 2, 3, 3],272"vals" => [1, 13, 3, 87, 1, 6]273]?;274275// groups276// 1 => [1, 13]277// 2 => [3, 87]278// 3 => [1, 6]279280let out = df281.clone()282.lazy()283.group_by_stable([col("groups")])284.agg([when(col("vals").eq(lit(1)))285.then(col("vals").sum())286.otherwise(lit(90))287.alias("vals")])288.collect()?;289290// if vals == 1 then sum(vals) else vals291// [14, 90]292// [90, 90]293// [7, 90]294let out = out.column("vals")?;295let out = out.explode(false)?;296let out = out.i32()?;297assert_eq!(298Vec::from(out),299&[Some(14), Some(90), Some(90), Some(90), Some(7), Some(90)]300);301302let out = df303.lazy()304.group_by_stable([col("groups")])305.agg([when(col("vals").eq(lit(1)))306.then(lit(90))307.otherwise(col("vals").sum())308.alias("vals")])309.collect()?;310311// if vals == 1 then 90 else sum(vals)312// [90, 14]313// [90, 90]314// [90, 7]315let out = out.column("vals")?;316let out = out.explode(false)?;317let out = out.i32()?;318assert_eq!(319Vec::from(out),320&[Some(90), Some(14), Some(90), Some(90), Some(90), Some(7)]321);322323Ok(())324}325326#[test]327fn test_binary_agg_context_2() -> PolarsResult<()> {328let df = df![329"groups" => [1, 1, 2, 2, 3, 3],330"vals" => [1, 2, 3, 4, 5, 6]331]?;332333// this is complex because we first aggregate one expression of the binary operation.334335let out = df336.clone()337.lazy()338.group_by_stable([col("groups")])339.agg([(col("vals").first() - col("vals")).alias("vals")])340.collect()?;341342// 0 - [1, 2] = [0, -1]343// 3 - [3, 4] = [0, -1]344// 5 - [5, 6] = [0, -1]345let out = out.column("vals")?;346let out = out.explode(false)?;347let out = out.i32()?;348assert_eq!(349Vec::from(out),350&[Some(0), Some(-1), Some(0), Some(-1), Some(0), Some(-1)]351);352353// Same, but now we reverse the lhs / rhs.354let out = df355.lazy()356.group_by_stable([col("groups")])357.agg([((col("vals")) - col("vals").first()).alias("vals")])358.collect()?;359360// [1, 2] - 1 = [0, 1]361// [3, 4] - 3 = [0, 1]362// [5, 6] - 5 = [0, 1]363let out = out.column("vals")?;364let out = out.explode(false)?;365let out = out.i32()?;366assert_eq!(367Vec::from(out),368&[Some(0), Some(1), Some(0), Some(1), Some(0), Some(1)]369);370371Ok(())372}373374#[test]375fn test_binary_agg_context_3() -> PolarsResult<()> {376let df = fruits_cars();377378let out = df379.lazy()380.group_by_stable([col("cars")])381.agg([(col("A") - col("A").first()).last().alias("last")])382.collect()?;383384let out = out.column("last")?;385assert_eq!(out.get(0)?, AnyValue::Int32(4));386assert_eq!(out.get(1)?, AnyValue::Int32(0));387388Ok(())389}390391#[test]392fn test_shift_elementwise_issue_2509() -> PolarsResult<()> {393let df = df![394"x"=> [0, 0, 0, 1, 1, 1, 2, 2, 2],395"y"=> [0, 10, 20, 0, 10, 20, 0, 10, 20]396]?;397let out = df398.lazy()399// Don't use maintain order here! That hides the bug400.group_by([col("x")])401.agg(&[(col("y").shift(lit(-1)) + col("x")).alias("sum")])402.sort(["x"], Default::default())403.collect()?;404405let out = out.explode(["sum"])?;406let out = out.column("sum")?;407assert_eq!(out.get(0)?, AnyValue::Int32(10));408assert_eq!(out.get(1)?, AnyValue::Int32(20));409assert_eq!(out.get(2)?, AnyValue::Null);410assert_eq!(out.get(3)?, AnyValue::Int32(11));411assert_eq!(out.get(4)?, AnyValue::Int32(21));412assert_eq!(out.get(5)?, AnyValue::Null);413414Ok(())415}416417#[test]418fn take_aggregations() -> PolarsResult<()> {419let df = df![420"user" => ["lucy", "bob", "bob", "lucy", "tim"],421"book" => ["c", "b", "a", "a", "a"],422"count" => [3, 1, 2, 1, 1]423]?;424425let out = df426.clone()427.lazy()428.group_by([col("user")])429.agg([col("book").get(col("count").arg_max()).alias("fav_book")])430.sort(["user"], Default::default())431.collect()?;432433let s = out.column("fav_book")?;434assert_eq!(s.get(0)?, AnyValue::String("a"));435assert_eq!(s.get(1)?, AnyValue::String("c"));436assert_eq!(s.get(2)?, AnyValue::String("a"));437438let out = df439.clone()440.lazy()441.group_by([col("user")])442.agg([443// keep the head as it test slice correctness444col("book")445.gather(col("count").arg_sort(true, false).head(Some(2)))446.alias("ordered"),447])448.sort(["user"], Default::default())449.collect()?;450let s = out.column("ordered")?;451let flat = s.explode(false)?;452let flat = flat.str()?;453let vals = flat.into_no_null_iter().collect::<Vec<_>>();454assert_eq!(vals, ["a", "b", "c", "a", "a"]);455456let out = df457.lazy()458.group_by([col("user")])459.agg([col("book").get(lit(0)).alias("take_lit")])460.sort(["user"], Default::default())461.collect()?;462463let taken = out.column("take_lit")?;464let taken = taken.str()?;465let vals = taken.into_no_null_iter().collect::<Vec<_>>();466assert_eq!(vals, ["b", "c", "a"]);467468Ok(())469}470#[test]471fn test_take_consistency() -> PolarsResult<()> {472let df = fruits_cars();473let out = df474.clone()475.lazy()476.select([col("A").arg_sort(true, false).get(lit(0))])477.collect()?;478479let a = out.column("A")?;480let a = a.idx()?;481assert_eq!(a.get(0), Some(4));482483let out = df484.clone()485.lazy()486.group_by_stable([col("cars")])487.agg([col("A").arg_sort(true, false).get(lit(0))])488.collect()?;489490let out = out.column("A")?;491let out = out.idx()?;492assert_eq!(Vec::from(out), &[Some(3), Some(0)]);493494let out_df = df495.lazy()496.group_by_stable([col("cars")])497.agg([498col("A"),499col("A").arg_sort(true, false).get(lit(0)).alias("1"),500col("A")501.get(col("A").arg_sort(true, false).get(lit(0)))502.alias("2"),503])504.collect()?;505506let out = out_df.column("2")?;507let out = out.i32()?;508assert_eq!(Vec::from(out), &[Some(5), Some(2)]);509510let out = out_df.column("1")?;511let out = out.idx()?;512assert_eq!(Vec::from(out), &[Some(3), Some(0)]);513514Ok(())515}516517#[test]518fn test_take_in_groups() -> PolarsResult<()> {519let df = fruits_cars();520521let out = df522.lazy()523.sort(["fruits"], Default::default())524.select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")])525.collect()?;526527assert_eq!(528Vec::from(out.column("taken")?.i32()?),529&[Some(3), Some(3), Some(5), Some(5), Some(5)]530);531Ok(())532}533534#[test]535fn test_anonymous_function_returns_scalar_all_null_20679() {536use std::sync::Arc;537538fn reduction_function(column: Column) -> PolarsResult<Column> {539let val = column.get(0)?.into_static();540let col = Column::new_scalar("".into(), Scalar::new(column.dtype().clone(), val), 1);541Ok(col)542}543544let a = Column::new("a".into(), &[0, 0, 1]);545let dtype = DataType::Null;546let b = Column::new_scalar("b".into(), Scalar::new(dtype, AnyValue::Null), 3);547let df = DataFrame::new(vec![a, b]).unwrap();548549let f = move |c: &mut [Column]| reduction_function(std::mem::take(&mut c[0]));550let dt = |_: &Schema, fs: &[Field]| Ok(fs[0].clone());551552let f = BaseColumnUdf::new(f, dt);553554let expr = Expr::AnonymousFunction {555input: vec![col("b")],556function: LazySerde::Deserialized(SpecialEq::new(Arc::new(f))),557options: FunctionOptions::aggregation(),558fmt_str: Box::new(PlSmallStr::EMPTY),559};560561let grouped_df = df562.lazy()563.group_by([col("a")])564.agg([expr])565.collect()566.unwrap();567568assert_eq!(grouped_df.get_columns()[1].dtype(), &DataType::Null);569}570571572