GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_numeric.py
from __future__ import annotations

from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def test_div() -> None:
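    # DIV() performs integer division; a NULL in either operand yields NULL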
    res = pl.sql(
        """
        SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a
        FROM (
          VALUES
            ('a', 20.5, 6),
            ('b', NULL, 12),
            ('c', 10.0, 24),
            ('d', 5.0, NULL),
            ('e', 2.5, 5)
        ) AS tbl(label, a, b)
        """
    ).collect()

    assert res.to_dict(as_series=False) == {
        "label": ["a", "b", "c", "d", "e"],
        "a_div_b": [3, None, 0, None, 0],
        "b_div_a": [0, None, 2, None, 2],
    }


def test_modulo() -> None:
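    # both the "%" operator and MOD() return the remainder; NULL inputs propagate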
    df = pl.DataFrame(
        {
            "a": [1.5, None, 3.0, 13 / 3, 5.0],
            "b": [6, 7, 8, 9, 10],
            "c": [11, 12, 13, 14, 15],
            "d": [16.5, 17.0, 18.5, None, 20.0],
        }
    )
    out = df.sql(
        """
        SELECT
            a % 2 AS a2,
            b % 3 AS b3,
            MOD(c, 4) AS c4,
            MOD(d, 5.5) AS d55
        FROM self
        """
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "a2": [1.5, None, 1.0, 1 / 3, 1.0],
                "b3": [0, 1, 2, 0, 1],
                "c4": [3, 0, 1, 2, 3],
                "d55": [0.0, 0.5, 2.0, None, 3.5],
            }
        ),
    )


@pytest.mark.parametrize(
    ("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"),
    [
        (64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)),
        (512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)),
        (512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)),
        (-1024.75, "decimal", "(10,0)", D("-1024"), pl.Decimal(10, 0)),
        (-1024.75, "numeric", "(10)", D("-1024"), pl.Decimal(10, 0)),
        (-1024.75, "dec", "", D("-1024.75"), pl.Decimal(38, 9)),
    ],
)
def test_numeric_decimal_type(
    value: float,
    sqltype: str,
    prec_scale: str,
    expected_value: D,
    expected_dtype: PolarsDataType,
) -> None:
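    # NUMERIC/DECIMAL/DEC casts map to pl.Decimal with the requested precision/scale;
    # a bare "dec" cast is expected to default to Decimal(38, 9)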
    df = pl.DataFrame({"n": [value]})
    with pl.SQLContext(df=df) as ctx:
        result = ctx.execute(
            f"""
            SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
            """
        )
        expected = pl.LazyFrame(
            data={"dec": [expected_value]},
            schema={"dec": expected_dtype},
        )
        assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("decimals", "expected"),
    [
        (0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]),
        (1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]),
        (2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]),
        (3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]),
        (4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]),
    ],
)
def test_round_ndigits(decimals: int, expected: list[float]) -> None:
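    # ROUND(n) with no second argument rounds to the nearest integer;
    # ROUND(n, d) rounds to d decimal places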
    df = pl.DataFrame(
        {"n": [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]},
    )
    with pl.SQLContext(df=df, eager=True) as ctx:
        if decimals == 0:
            out = ctx.execute("SELECT ROUND(n) AS n FROM df")
            assert_series_equal(out["n"], pl.Series("n", values=expected))

        out = ctx.execute(f'SELECT ROUND("n",{decimals}) AS n FROM df')
        assert_series_equal(out["n"], pl.Series("n", values=expected))


def test_round_ndigits_errors() -> None:
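    # invalid ROUND arguments (non-numeric, negative, or too many) should raise
    # SQLSyntaxError / SQLInterfaceError with the messages matched below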
    df = pl.DataFrame({"n": [99.999]})
    with pl.SQLContext(df=df, eager=True) as ctx:
        with pytest.raises(
            SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)"
        ):
            ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")

        with pytest.raises(
            SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)"
        ):
            ctx.execute("SELECT ROUND(n,-1) AS n FROM df")

        with pytest.raises(
            SQLSyntaxError, match=r"ROUND expects 1-2 arguments \(found 4\)"
        ):
            ctx.execute("SELECT ROUND(1.2345,6,7,8) AS n FROM df")


def test_stddev_variance() -> None:
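    # the expected values below correspond to sample (ddof=1) standard deviation
    # and variance; nulls (see "v3") are ignored by the aggregations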
    df = pl.DataFrame(
        {
            "v1": [-1.0, 0.0, 1.0],
            "v2": [5.5, 0.0, 3.0],
            "v3": [-10, None, 10],
            "v4": [-100.0, 0.0, -50.0],
        }
    )
    with pl.SQLContext(df=df) as ctx:
        # note: we support all common aliases for std/var
        out = ctx.execute(
            """
            SELECT
              STDEV(v1) AS "v1_std",
              STDDEV(v2) AS "v2_std",
              STDEV_SAMP(v3) AS "v3_std",
              STDDEV_SAMP(v4) AS "v4_std",
              VAR(v1) AS "v1_var",
              VARIANCE(v2) AS "v2_var",
              VARIANCE(v3) AS "v3_var",
              VAR_SAMP(v4) AS "v4_var"
            FROM df
            """
        ).collect()

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "v1_std": [1.0],
                "v2_std": [2.7537852736431],
                "v3_std": [14.142135623731],
                "v4_std": [50.0],
                "v1_var": [1.0],
                "v2_var": [7.5833333333333],
                "v3_var": [200.0],
                "v4_var": [2500.0],
            }
        ),
    )