// Path: blob/main/crates/polars-lazy/src/tests/aggregations.rs
// 8448 views
use polars_ops::prelude::ListNameSpaceImpl;1use polars_utils::unitvec;23use super::*;45#[test]6#[cfg(feature = "dtype-datetime")]7fn test_agg_list_type() -> PolarsResult<()> {8let s = Series::new("foo".into(), &[1, 2, 3]);9let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?;1011let l = unsafe { s.agg_list(&GroupsType::Idx(vec![(0, unitvec![0, 1, 2])].into())) };1213let result = match l.dtype() {14DataType::List(inner) => {15matches!(&**inner, DataType::Datetime(TimeUnit::Nanoseconds, None))16},17_ => false,18};19assert!(result);2021Ok(())22}2324#[test]25fn test_agg_exprs() -> PolarsResult<()> {26let df = fruits_cars();2728// a binary expression followed by a function and an aggregation. See if it runs29let out = df30.lazy()31.group_by_stable([col("cars")])32.agg([(lit(1) - col("A"))33.map(|s| Ok(&s * 2), |_, f| Ok(f.clone()))34.alias("foo")])35.collect()?;36let ca = out.column("foo")?.list()?;37let out = ca.lst_lengths();3839assert_eq!(Vec::from(&out), &[Some(4), Some(1)]);40Ok(())41}4243#[test]44fn test_agg_unique_first() -> PolarsResult<()> {45let df = df![46"g"=> [1, 1, 2, 2, 3, 4, 1],47"v"=> [1, 2, 2, 2, 3, 4, 1],48]?;4950let out = df51.lazy()52.group_by_stable([col("g")])53.agg([54col("v").unique().first().alias("v_first"),55col("v")56.unique()57.sort(Default::default())58.first()59.alias("true_first"),60col("v").unique().implode(),61])62.collect()?;6364let a = out.column("v_first").unwrap();65let a = a.as_materialized_series().sum::<i32>().unwrap();66// can be both because unique does not guarantee order67assert!(a == 10 || a == 11);6869let a = out.column("true_first").unwrap();70let a = a.as_materialized_series().sum::<i32>().unwrap();71// can be both because unique does not guarantee order72assert_eq!(a, 10);7374Ok(())75}7677#[test]78#[cfg(feature = "cum_agg")]79fn test_cum_sum_agg_as_key() -> PolarsResult<()> {80let df = df![81"depth" => &[0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9],82"soil" => &["peat", "peat", "peat", "silt", "silt", "silt", "sand", 
"sand", "peat", "peat"]83]?;84// this checks if the grouper can work with the complex query as a key8586let out = df87.lazy()88.group_by([col("soil")89.neq(col("soil").shift_and_fill(lit(1), col("soil").first()))90.cum_sum(false)91.alias("key")])92.agg([col("depth").max().name().keep()])93.sort(["depth"], Default::default())94.collect()?;9596assert_eq!(97Vec::from(out.column("key")?.u32()?),98&[Some(0), Some(1), Some(2), Some(3)]99);100assert_eq!(101Vec::from(out.column("depth")?.i32()?),102&[Some(2), Some(5), Some(7), Some(9)]103);104105Ok(())106}107108#[test]109#[cfg(feature = "moment")]110fn test_auto_skew_kurtosis_agg() -> PolarsResult<()> {111let df = fruits_cars();112113let out = df114.lazy()115.group_by([col("fruits")])116.agg([117col("B").skew(false).alias("bskew"),118col("B").kurtosis(false, false).alias("bkurt"),119])120.collect()?;121122assert!(matches!(out.column("bskew")?.dtype(), DataType::Float64));123assert!(matches!(out.column("bkurt")?.dtype(), DataType::Float64));124125Ok(())126}127128#[test]129fn test_auto_list_agg() -> PolarsResult<()> {130let df = fruits_cars();131132// test if alias executor adds a list after shift and fill133let out = df134.clone()135.lazy()136.group_by([col("fruits")])137.agg([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")])138.collect()?;139140assert!(matches!(out.column("foo")?.dtype(), DataType::List(_)));141142// test if it runs and group_by executor thus implements a list after shift_and_fill143let _out = df144.clone()145.lazy()146.group_by([col("fruits")])147.agg([col("B").shift_and_fill(lit(-1), lit(-1))])148.collect()?;149150// test if window expr executor adds list151let _out = df152.clone()153.lazy()154.select([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")])155.collect()?;156157let _out = df158.lazy()159.select([col("B").shift_and_fill(lit(-1), lit(-1))])160.collect()?;161Ok(())162}163#[test]164#[cfg(feature = "rolling_window")]165fn test_power_in_agg_list1() -> PolarsResult<()> {166let df = 
fruits_cars();167168// this test if the group tuples are correctly updated after169// a flat apply on a final aggregation170let out = df171.lazy()172.group_by([col("fruits")])173.agg([174col("A")175.rolling_min(RollingOptionsFixedWindow {176window_size: 1,177..Default::default()178})179.alias("input"),180col("A")181.rolling_min(RollingOptionsFixedWindow {182window_size: 1,183..Default::default()184})185.pow(2.0)186.alias("foo"),187])188.sort(189["fruits"],190SortMultipleOptions::default().with_order_descending(true),191)192.collect()?;193194let agg = out.column("foo")?.list()?;195let first = agg.get_as_series(0).unwrap();196let vals = first.f64()?;197assert_eq!(Vec::from(vals), &[Some(1.0), Some(4.0), Some(25.0)]);198199Ok(())200}201202#[test]203#[cfg(feature = "rolling_window")]204fn test_power_in_agg_list2() -> PolarsResult<()> {205let df = fruits_cars();206207// this test if the group tuples are correctly updated after208// a flat apply on evaluate_on_groups209let out = df210.lazy()211.group_by([col("fruits")])212.agg([col("A")213.rolling_min(RollingOptionsFixedWindow {214window_size: 2,215min_periods: 2,216..Default::default()217})218.pow(2.0)219.sum()220.alias("foo")])221.sort(222["fruits"],223SortMultipleOptions::default().with_order_descending(true),224)225.collect()?;226227let agg = out.column("foo")?.f64()?;228assert_eq!(Vec::from(agg), &[Some(5.0), Some(9.0)]);229230Ok(())231}232#[test]233fn test_binary_agg_context_0() -> PolarsResult<()> {234let df = df![235"groups" => [1, 1, 2, 2, 3, 3],236"vals" => [1, 2, 3, 4, 5, 6]237]238.unwrap();239240let out = df241.lazy()242.group_by_stable([col("groups")])243.agg([when(col("vals").first().neq(lit(1)))244.then(repeat(lit("a"), len()))245.otherwise(repeat(lit("b"), len()))246.alias("foo")])247.collect()248.unwrap();249250let out = out.column("foo")?;251let out = out.explode(ExplodeOptions {252empty_as_null: true,253keep_nulls: true,254})?;255let out = 
out.str()?;256assert_eq!(257Vec::from(out),258&[259Some("b"),260Some("b"),261Some("a"),262Some("a"),263Some("a"),264Some("a")265]266);267Ok(())268}269270// just like binary expression, this must be changed. This can work271#[test]272fn test_binary_agg_context_1() -> PolarsResult<()> {273let df = df![274"groups" => [1, 1, 2, 2, 3, 3],275"vals" => [1, 13, 3, 87, 1, 6]276]?;277278// groups279// 1 => [1, 13]280// 2 => [3, 87]281// 3 => [1, 6]282283let out = df284.clone()285.lazy()286.group_by_stable([col("groups")])287.agg([when(col("vals").eq(lit(1)))288.then(col("vals").sum())289.otherwise(lit(90))290.alias("vals")])291.collect()?;292293// if vals == 1 then sum(vals) else vals294// [14, 90]295// [90, 90]296// [7, 90]297let out = out.column("vals")?;298let out = out.explode(ExplodeOptions {299empty_as_null: true,300keep_nulls: true,301})?;302let out = out.i32()?;303assert_eq!(304Vec::from(out),305&[Some(14), Some(90), Some(90), Some(90), Some(7), Some(90)]306);307308let out = df309.lazy()310.group_by_stable([col("groups")])311.agg([when(col("vals").eq(lit(1)))312.then(lit(90))313.otherwise(col("vals").sum())314.alias("vals")])315.collect()?;316317// if vals == 1 then 90 else sum(vals)318// [90, 14]319// [90, 90]320// [90, 7]321let out = out.column("vals")?;322let out = out.explode(ExplodeOptions {323empty_as_null: true,324keep_nulls: true,325})?;326let out = out.i32()?;327assert_eq!(328Vec::from(out),329&[Some(90), Some(14), Some(90), Some(90), Some(90), Some(7)]330);331332Ok(())333}334335#[test]336fn test_binary_agg_context_2() -> PolarsResult<()> {337let df = df![338"groups" => [1, 1, 2, 2, 3, 3],339"vals" => [1, 2, 3, 4, 5, 6]340]?;341342// this is complex because we first aggregate one expression of the binary operation.343344let out = df345.clone()346.lazy()347.group_by_stable([col("groups")])348.agg([(col("vals").first() - col("vals")).alias("vals")])349.collect()?;350351// 0 - [1, 2] = [0, -1]352// 3 - [3, 4] = [0, -1]353// 5 - [5, 6] = [0, -1]354let out = 
out.column("vals")?;355let out = out.explode(ExplodeOptions {356empty_as_null: true,357keep_nulls: true,358})?;359let out = out.i32()?;360assert_eq!(361Vec::from(out),362&[Some(0), Some(-1), Some(0), Some(-1), Some(0), Some(-1)]363);364365// Same, but now we reverse the lhs / rhs.366let out = df367.lazy()368.group_by_stable([col("groups")])369.agg([((col("vals")) - col("vals").first()).alias("vals")])370.collect()?;371372// [1, 2] - 1 = [0, 1]373// [3, 4] - 3 = [0, 1]374// [5, 6] - 5 = [0, 1]375let out = out.column("vals")?;376let out = out.explode(ExplodeOptions {377empty_as_null: true,378keep_nulls: true,379})?;380let out = out.i32()?;381assert_eq!(382Vec::from(out),383&[Some(0), Some(1), Some(0), Some(1), Some(0), Some(1)]384);385386Ok(())387}388389#[test]390fn test_binary_agg_context_3() -> PolarsResult<()> {391let df = fruits_cars();392393let out = df394.lazy()395.group_by_stable([col("cars")])396.agg([(col("A") - col("A").first()).last().alias("last")])397.collect()?;398399let out = out.column("last")?;400assert_eq!(out.get(0)?, AnyValue::Int32(4));401assert_eq!(out.get(1)?, AnyValue::Int32(0));402403Ok(())404}405406#[test]407fn test_shift_elementwise_issue_2509() -> PolarsResult<()> {408let df = df![409"x"=> [0, 0, 0, 1, 1, 1, 2, 2, 2],410"y"=> [0, 10, 20, 0, 10, 20, 0, 10, 20]411]?;412let out = df413.lazy()414// Don't use maintain order here! 
That hides the bug415.group_by([col("x")])416.agg(&[(col("y").shift(lit(-1)) + col("x")).alias("sum")])417.sort(["x"], Default::default())418.collect()?;419420let out = out.explode(421["sum"],422ExplodeOptions {423empty_as_null: true,424keep_nulls: true,425},426)?;427let out = out.column("sum")?;428assert_eq!(out.get(0)?, AnyValue::Int32(10));429assert_eq!(out.get(1)?, AnyValue::Int32(20));430assert_eq!(out.get(2)?, AnyValue::Null);431assert_eq!(out.get(3)?, AnyValue::Int32(11));432assert_eq!(out.get(4)?, AnyValue::Int32(21));433assert_eq!(out.get(5)?, AnyValue::Null);434435Ok(())436}437438#[test]439fn take_aggregations() -> PolarsResult<()> {440let df = df![441"user" => ["lucy", "bob", "bob", "lucy", "tim"],442"book" => ["c", "b", "a", "a", "a"],443"count" => [3, 1, 2, 1, 1]444]?;445446let out = df447.clone()448.lazy()449.group_by([col("user")])450.agg([col("book")451.get(col("count").arg_max(), false)452.alias("fav_book")])453.sort(["user"], Default::default())454.collect()?;455456let s = out.column("fav_book")?;457assert_eq!(s.get(0)?, AnyValue::String("a"));458assert_eq!(s.get(1)?, AnyValue::String("c"));459assert_eq!(s.get(2)?, AnyValue::String("a"));460461let out = df462.clone()463.lazy()464.group_by([col("user")])465.agg([466// keep the head as it test slice correctness467col("book")468.gather(col("count").arg_sort(true, false).head(Some(2)))469.alias("ordered"),470])471.sort(["user"], Default::default())472.collect()?;473let s = out.column("ordered")?;474let flat = s.explode(ExplodeOptions {475empty_as_null: true,476keep_nulls: true,477})?;478let flat = flat.str()?;479let vals = flat.into_no_null_iter().collect::<Vec<_>>();480assert_eq!(vals, ["a", "b", "c", "a", "a"]);481482let out = df483.lazy()484.group_by([col("user")])485.agg([col("book").get(lit(0), false).alias("take_lit")])486.sort(["user"], Default::default())487.collect()?;488489let taken = out.column("take_lit")?;490let taken = taken.str()?;491let vals = 
taken.into_no_null_iter().collect::<Vec<_>>();492assert_eq!(vals, ["b", "c", "a"]);493494Ok(())495}496#[test]497fn test_take_consistency() -> PolarsResult<()> {498let df = fruits_cars();499let out = df500.clone()501.lazy()502.select([col("A").arg_sort(true, false).get(lit(0), false)])503.collect()?;504505let a = out.column("A")?;506let a = a.idx()?;507assert_eq!(a.get(0), Some(4));508509let out = df510.clone()511.lazy()512.group_by_stable([col("cars")])513.agg([col("A").arg_sort(true, false).get(lit(0), false)])514.collect()?;515516let out = out.column("A")?;517let out = out.idx()?;518assert_eq!(Vec::from(out), &[Some(3), Some(0)]);519520let out_df = df521.lazy()522.group_by_stable([col("cars")])523.agg([524col("A"),525col("A").arg_sort(true, false).get(lit(0), false).alias("1"),526col("A")527.get(col("A").arg_sort(true, false).get(lit(0), false), false)528.alias("2"),529])530.collect()?;531532let out = out_df.column("2")?;533let out = out.i32()?;534assert_eq!(Vec::from(out), &[Some(5), Some(2)]);535536let out = out_df.column("1")?;537let out = out.idx()?;538assert_eq!(Vec::from(out), &[Some(3), Some(0)]);539540Ok(())541}542543#[test]544fn test_take_in_groups() -> PolarsResult<()> {545let df = fruits_cars();546547let out = df548.lazy()549.sort(["fruits"], Default::default())550.select([col("B")551.get(lit(0u32), false)552.over([col("fruits")])553.alias("taken")])554.collect()?;555556assert_eq!(557Vec::from(out.column("taken")?.i32()?),558&[Some(3), Some(3), Some(5), Some(5), Some(5)]559);560Ok(())561}562563#[test]564fn test_anonymous_function_returns_scalar_all_null_20679() {565use std::sync::Arc;566567fn reduction_function(column: Column) -> PolarsResult<Column> {568let val = column.get(0)?.into_static();569let col = Column::new_scalar("".into(), Scalar::new(column.dtype().clone(), val), 1);570Ok(col)571}572573let a = Column::new("a".into(), &[0, 0, 1]);574let dtype = DataType::Null;575let b = Column::new_scalar("b".into(), Scalar::new(dtype, AnyValue::Null), 
3);576let df = DataFrame::new_infer_height(vec![a, b]).unwrap();577578let f = move |c: &mut [Column]| reduction_function(std::mem::take(&mut c[0]));579let dt = |_: &Schema, fs: &[Field]| Ok(fs[0].clone());580581let f = BaseColumnUdf::new(f, dt);582583let expr = Expr::AnonymousFunction {584input: vec![col("b")],585function: LazySerde::Deserialized(SpecialEq::new(Arc::new(f))),586options: FunctionOptions::aggregation(),587fmt_str: Box::new(PlSmallStr::EMPTY),588};589590let grouped_df = df591.lazy()592.group_by([col("a")])593.agg([expr])594.collect()595.unwrap();596597assert_eq!(grouped_df.columns()[1].dtype(), &DataType::Null);598}599600601