# Path: blob/main/py-polars/tests/unit/sql/test_numeric.py
# 6939 views
from __future__ import annotations12from decimal import Decimal as D3from typing import TYPE_CHECKING45import pytest67import polars as pl8from polars.exceptions import SQLInterfaceError, SQLSyntaxError9from polars.testing import assert_frame_equal, assert_series_equal1011if TYPE_CHECKING:12from polars._typing import PolarsDataType131415def test_div() -> None:16res = pl.sql(17"""18SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a19FROM (20VALUES21('a', 20.5, 6),22('b', NULL, 12),23('c', 10.0, 24),24('d', 5.0, NULL),25('e', 2.5, 5)26) AS tbl(label, a, b)27"""28).collect()2930assert res.to_dict(as_series=False) == {31"label": ["a", "b", "c", "d", "e"],32"a_div_b": [3, None, 0, None, 0],33"b_div_a": [0, None, 2, None, 2],34}353637def test_modulo() -> None:38df = pl.DataFrame(39{40"a": [1.5, None, 3.0, 13 / 3, 5.0],41"b": [6, 7, 8, 9, 10],42"c": [11, 12, 13, 14, 15],43"d": [16.5, 17.0, 18.5, None, 20.0],44}45)46out = df.sql(47"""48SELECT49a % 2 AS a2,50b % 3 AS b3,51MOD(c, 4) AS c4,52MOD(d, 5.5) AS d5553FROM self54"""55)56assert_frame_equal(57out,58pl.DataFrame(59{60"a2": [1.5, None, 1.0, 1 / 3, 1.0],61"b3": [0, 1, 2, 0, 1],62"c4": [3, 0, 1, 2, 3],63"d55": [0.0, 0.5, 2.0, None, 3.5],64}65),66)676869@pytest.mark.parametrize(70("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"),71[72(64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)),73(512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)),74(512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)),75(-1024.75, "decimal", "(10,0)", D("-1024"), pl.Decimal(10, 0)),76(-1024.75, "numeric", "(10)", D("-1024"), pl.Decimal(10, 0)),77(-1024.75, "dec", "", D("-1024.75"), pl.Decimal(38, 9)),78],79)80def test_numeric_decimal_type(81value: float,82sqltype: str,83prec_scale: str,84expected_value: D,85expected_dtype: PolarsDataType,86) -> None:87df = pl.DataFrame({"n": [value]})88with pl.SQLContext(df=df) as ctx:89result = ctx.execute(90f"""91SELECT n::{sqltype}{prec_scale} AS "dec" FROM 
df92"""93)94expected = pl.LazyFrame(95data={"dec": [expected_value]},96schema={"dec": expected_dtype},97)98assert_frame_equal(result, expected)99100101@pytest.mark.parametrize(102("decimals", "expected"),103[104(0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]),105(1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]),106(2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]),107(3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]),108(4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]),109],110)111def test_round_ndigits(decimals: int, expected: list[float]) -> None:112df = pl.DataFrame(113{"n": [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]},114)115with pl.SQLContext(df=df, eager=True) as ctx:116if decimals == 0:117out = ctx.execute("SELECT ROUND(n) AS n FROM df")118assert_series_equal(out["n"], pl.Series("n", values=expected))119120out = ctx.execute(f'SELECT ROUND("n",{decimals}) AS n FROM df')121assert_series_equal(out["n"], pl.Series("n", values=expected))122123124def test_round_ndigits_errors() -> None:125df = pl.DataFrame({"n": [99.999]})126with pl.SQLContext(df=df, eager=True) as ctx:127with pytest.raises(128SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)"129):130ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")131132with pytest.raises(133SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)"134):135ctx.execute("SELECT ROUND(n,-1) AS n FROM df")136137with pytest.raises(138SQLSyntaxError, match=r"ROUND expects 1-2 arguments \(found 4\)"139):140ctx.execute("SELECT ROUND(1.2345,6,7,8) AS n FROM df")141142143def test_stddev_variance() -> None:144df = pl.DataFrame(145{146"v1": [-1.0, 0.0, 1.0],147"v2": [5.5, 0.0, 3.0],148"v3": [-10, None, 10],149"v4": [-100.0, 0.0, -50.0],150}151)152with pl.SQLContext(df=df) as ctx:153# note: we support all common aliases for std/var154out = ctx.execute(155"""156SELECT157STDEV(v1) AS "v1_std",158STDDEV(v2) AS "v2_std",159STDEV_SAMP(v3) AS "v3_std",160STDDEV_SAMP(v4) AS "v4_std",161VAR(v1) AS 
"v1_var",162VARIANCE(v2) AS "v2_var",163VARIANCE(v3) AS "v3_var",164VAR_SAMP(v4) AS "v4_var"165FROM df166"""167).collect()168169assert_frame_equal(170out,171pl.DataFrame(172{173"v1_std": [1.0],174"v2_std": [2.7537852736431],175"v3_std": [14.142135623731],176"v4_std": [50.0],177"v1_var": [1.0],178"v2_var": [7.5833333333333],179"v3_var": [200.0],180"v4_var": [2500.0],181}182),183)184185186