# Path: blob/main/py-polars/tests/unit/sql/test_numeric.py
# 8354 views
from __future__ import annotations

from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def test_div() -> None:
    """Check SQL ``DIV`` on an inline VALUES table, including NULL operands."""
    res = pl.sql(
        """
        SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a
        FROM (
          VALUES
            ('a', 20.5, 6),
            ('b', NULL, 12),
            ('c', 10.0, 24),
            ('d', 5.0, NULL),
            ('e', 2.5, 5)
        ) AS tbl(label, a, b)
        """
    ).collect()

    # Rows with a NULL operand ('b' and 'd') are expected to yield NULL
    # for both quotient columns.
    assert res.to_dict(as_series=False) == {
        "label": ["a", "b", "c", "d", "e"],
        "a_div_b": [3, None, 0, None, 0],
        "b_div_a": [0, None, 2, None, 2],
    }


def test_modulo() -> None:
    """Check the SQL ``%`` operator and ``MOD`` function on int/float columns."""
    df = pl.DataFrame(
        {
            "a": [1.5, None, 3.0, 13 / 3, 5.0],
            "b": [6, 7, 8, 9, 10],
            "c": [11, 12, 13, 14, 15],
            "d": [16.5, 17.0, 18.5, None, 20.0],
        }
    )
    out = df.sql(
        """
        SELECT
          ROW_NUMBER() AS idx,
          a % 2 AS a2,
          b % 3 AS b3,
          MOD(c, 4) AS c4,
          MOD(d, 5.5) AS d55
        FROM self
        """
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "idx": [1, 2, 3, 4, 5],
                "a2": [1.5, None, 1.0, 1 / 3, 1.0],
                "b3": [0, 1, 2, 0, 1],
                "c4": [3, 0, 1, 2, 3],
                "d55": [0.0, 0.5, 2.0, None, 3.5],
            },
            # The expected frame pins the ROW_NUMBER() column to UInt32.
            schema_overrides={"idx": pl.UInt32},
        ),
    )


@pytest.mark.parametrize(
    ("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"),
    [
        (64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)),
        (512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)),
        (512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)),
        (-1024.75, "decimal", "(10,0)", D("-1025"), pl.Decimal(10, 0)),
        (-1024.75, "numeric", "(10)", D("-1025"), pl.Decimal(10, 0)),
        (-1024.75, "dec", "", D("-1024.75"), pl.Decimal(38, 9)),
    ],
)
def test_numeric_decimal_type(
    value: float,
    sqltype: str,
    prec_scale: str,
    expected_value: D,
    expected_dtype: PolarsDataType,
) -> None:
    """
    Check ``::numeric`` / ``::decimal`` / ``::dec`` casts of a float column.

    Each case casts a single float using the given SQL type name and optional
    precision/scale suffix, then compares the resulting value and dtype.
    """
    df = pl.DataFrame({"n": [value]})
    with pl.SQLContext(df=df) as ctx:
        result = ctx.execute(
            f"""
            SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
            """
        )
    expected = pl.LazyFrame(
        data={"dec": [expected_value]},
        schema={"dec": expected_dtype},
    )
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("decimals", "expected"),
    [
        (0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]),
        (1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]),
        (2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]),
        (3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]),
        (4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]),
    ],
)
def test_round_ndigits(decimals: int, expected: list[float]) -> None:
    """Check SQL ``ROUND`` with an explicit number of decimals (0 through 4)."""
    df = pl.DataFrame(
        {"n": [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]},
    )
    with pl.SQLContext(df=df, eager=True) as ctx:
        if decimals == 0:
            # The single-argument form is compared against the same expected
            # values as the explicit ROUND(n, 0) call below.
            out = ctx.execute("SELECT ROUND(n) AS n FROM df")
            assert_series_equal(out["n"], pl.Series("n", values=expected))

        out = ctx.execute(f'SELECT ROUND("n",{decimals}) AS n FROM df')
        assert_series_equal(out["n"], pl.Series("n", values=expected))


def test_round_ndigits_errors() -> None:
    """Check that invalid ``ROUND`` arguments raise the expected SQL errors."""
    df = pl.DataFrame({"n": [99.999]})
    with pl.SQLContext(df=df, eager=True) as ctx:
        # Non-numeric decimals argument.
        with pytest.raises(
            SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)"
        ):
            ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")

        # Negative decimals argument.
        with pytest.raises(
            SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)"
        ):
            ctx.execute("SELECT ROUND(n,-1) AS n FROM df")

        # Too many arguments.
        with pytest.raises(
            SQLSyntaxError, match=r"ROUND expects 1-2 arguments \(found 4\)"
        ):
            ctx.execute("SELECT ROUND(1.2345,6,7,8) AS n FROM df")


def test_stddev_variance() -> None:
    """Check the SQL standard-deviation and variance aggregate aliases."""
    df = pl.DataFrame(
        {
            "v1": [-1.0, 0.0, 1.0],
            "v2": [5.5, 0.0, 3.0],
            "v3": [-10, None, 10],
            "v4": [-100.0, 0.0, -50.0],
        }
    )
    with pl.SQLContext(df=df) as ctx:
        # note: we support all common aliases for std/var
        out = ctx.execute(
            """
            SELECT
              STDEV(v1) AS "v1_std",
              STDDEV(v2) AS "v2_std",
              STDEV_SAMP(v3) AS "v3_std",
              STDDEV_SAMP(v4) AS "v4_std",
              VAR(v1) AS "v1_var",
              VARIANCE(v2) AS "v2_var",
              VARIANCE(v3) AS "v3_var",
              VAR_SAMP(v4) AS "v4_var"
            FROM df
            """
        ).collect()

        assert_frame_equal(
            out,
            pl.DataFrame(
                {
                    "v1_std": [1.0],
                    "v2_std": [2.7537852736431],
                    "v3_std": [14.142135623731],
                    "v4_std": [50.0],
                    "v1_var": [1.0],
                    "v2_var": [7.5833333333333],
                    "v3_var": [200.0],
                    "v4_var": [2500.0],
                }
            ),
        )