// Path: blob/main/crates/polars-sql/tests/functions_aggregate.rs
// (6939 views)
use polars_core::prelude::*;1use polars_lazy::prelude::*;2use polars_plan::dsl::Expr;3use polars_sql::*;45fn create_df() -> LazyFrame {6df! {7"Data" => [1000, 2000, 3000, 4000, 5000, 6000]8}9.unwrap()10.lazy()11}1213fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) {14let df = create_df();15let alias = "TEST";1617let query = format!(18r#"19SELECT20{sql} as {alias}21FROM22df23"#24);2526let expected = df27.clone()28.select(&[expr.alias(alias)])29.sort([alias], Default::default())30.collect()31.unwrap();32let mut ctx = SQLContext::new();33ctx.register("df", df);3435let actual = ctx.execute(&query).unwrap().collect().unwrap();36(expected, actual)37}3839#[test]40fn test_median() {41let expr = col("Data").median();4243let sql_expr = "MEDIAN(Data)";44let (expected, actual) = create_expected(expr, sql_expr);4546assert!(expected.equals(&actual))47}4849#[test]50fn test_quantile_cont() {51for &q in &[0.25, 0.5, 0.75] {52let expr = col("Data").quantile(lit(q), QuantileMethod::Linear);5354let sql_expr = format!("QUANTILE_CONT(Data, {q})");55let (expected, actual) = create_expected(expr, &sql_expr);5657assert!(58expected.equals(&actual),59"q: {q}: expected {expected:?}, got {actual:?}"60)61}62}6364#[test]65fn test_quantile_disc() {66for &q in &[0.25, 0.5, 0.75] {67let expr = col("Data").quantile(lit(q), QuantileMethod::Equiprobable);6869let sql_expr = format!("QUANTILE_DISC(Data, {q})");70let (expected, actual) = create_expected(expr, &sql_expr);7172assert!(expected.equals(&actual))73}74}7576#[test]77fn test_quantile_out_of_range() {78for &q in &["-1", "2", "-0.01", "1.01"] {79for &func in &["QUANTILE_CONT", "QUANTILE_DISC"] {80let query = format!("SELECT {func}(Data, {q})");81let mut ctx = SQLContext::new();82ctx.register("df", create_df());83let actual = ctx.execute(&query);84assert!(actual.is_err())85}86}87}8889#[test]90fn test_quantile_disc_conformance() {91let expected = df![92"q" => [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],93"Data" => 
[1000, 1000, 2000, 2000, 3000, 3000, 4000, 5000, 5000, 6000, 6000],94]95.unwrap();9697let mut ctx = SQLContext::new();98ctx.register("df", create_df());99100let mut actual: Option<DataFrame> = None;101for &q in &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] {102let res = ctx103.execute(&format!(104"SELECT {q}::float as q, QUANTILE_DISC(Data, {q}) as Data FROM df"105))106.unwrap()107.collect()108.unwrap();109actual = if let Some(df) = actual {110Some(df.vstack(&res).unwrap())111} else {112Some(res)113};114}115116assert!(117expected.equals(actual.as_ref().unwrap()),118"expected {expected:?}, got {actual:?}"119)120}121122fn create_df_corr() -> LazyFrame {123df! {124"a" => [1, 2, 3, 4, 5, 6],125"b" => [2, 4, 10, 8, 9, 13],126"c" => ["a", "b", "a", "a", "b", "b"]127}128.unwrap()129.lazy()130}131132#[test]133fn test_corr() {134let df = create_df_corr();135136let expr_corr = pearson_corr(col("a"), col("b")).alias("corr");137let expr_cov = cov(col("a"), col("b"), 1).alias("cov");138let expr_cov_pop = cov(col("a"), col("b"), 0).alias("cov_pop");139let expected = df140.clone()141.select(&[expr_corr, expr_cov, expr_cov_pop])142.collect()143.unwrap();144145let mut ctx = SQLContext::new();146ctx.register("df", df);147let sql = r#"148SELECT149CORR(a, b) as corr,150COVAR(a, b) as covar,151COVAR_POP(a, b) as covar_pop152FROM df"#;153let actual = ctx.execute(sql).unwrap().collect().unwrap();154155assert_eq!(expected, actual, "expected {expected:?}, got {actual:?}");156}157158#[test]159fn test_corr_group_by() {160let df = create_df_corr();161162let expected = df163.clone()164.group_by(["c"])165.agg([166pearson_corr(col("a"), col("b")).alias("corr"),167cov(col("a"), col("b"), 1).alias("cov"),168])169.sort(["c"], Default::default())170.collect()171.unwrap();172173let mut ctx = SQLContext::new();174ctx.register("df", df);175let sql = r#"176SELECT177c,178CORR(a, b) AS corr,179COVAR(a, b) AS covar180FROM df181GROUP BY c182ORDER BY c"#;183let actual = 
ctx.execute(sql).unwrap().collect().unwrap();184185assert_eq!(expected, actual, "expected {expected:?}, got {actual:?}");186}187188189