# Captured from the CoCalc web viewer (site navigation links omitted).
# GitHub Repository: pola-rs/polars
# Path: blob/main/py-polars/tests/unit/sql/test_numeric.py
from __future__ import annotations

from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType
14
15
16
def test_div() -> None:
    """Check the SQL ``DIV`` function over an inline ``VALUES`` table.

    The table mixes a float column (``a``) with an integer column (``b``)
    and contains a NULL in each. ``DIV`` is evaluated in both directions
    (``a`` by ``b`` and ``b`` by ``a``, the latter via qualified
    ``tbl.``-prefixed names). The expected output holds whole-number
    quotients, with ``None`` on every row where either operand is NULL.
    """
    res = pl.sql(
        """
        SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a
        FROM (
          VALUES
            ('a', 20.5, 6),
            ('b', NULL, 12),
            ('c', 10.0, 24),
            ('d', 5.0, NULL),
            ('e', 2.5, 5)
        ) AS tbl(label, a, b)
        """
    ).collect()

    # Rows 'b' and 'd' each have one NULL operand, so both quotients are null.
    assert res.to_dict(as_series=False) == {
        "label": ["a", "b", "c", "d", "e"],
        "a_div_b": [3, None, 0, None, 0],
        "b_div_a": [0, None, 2, None, 2],
    }
36
37
38
def test_modulo() -> None:
    """Check SQL modulo via both the ``%`` operator and the ``MOD`` function.

    Float columns (``a``, ``d``) and integer columns (``b``, ``c``) are each
    reduced by a constant divisor; null inputs are expected to stay null.
    ``ROW_NUMBER()`` supplies a 1-based ``idx`` column, which the expected
    frame declares as ``UInt32``.
    """
    df = pl.DataFrame(
        {
            "a": [1.5, None, 3.0, 13 / 3, 5.0],
            "b": [6, 7, 8, 9, 10],
            "c": [11, 12, 13, 14, 15],
            "d": [16.5, 17.0, 18.5, None, 20.0],
        }
    )
    out = df.sql(
        """
        SELECT
          ROW_NUMBER() AS idx,
          a % 2 AS a2,
          b % 3 AS b3,
          MOD(c, 4) AS c4,
          MOD(d, 5.5) AS d55
        FROM self
        """
    )
    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "idx": [1, 2, 3, 4, 5],
                "a2": [1.5, None, 1.0, 1 / 3, 1.0],
                "b3": [0, 1, 2, 0, 1],
                "c4": [3, 0, 1, 2, 3],
                "d55": [0.0, 0.5, 2.0, None, 3.5],
            },
            schema_overrides={"idx": pl.UInt32},
        ),
    )
71
72
73
@pytest.mark.parametrize(
    ("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"),
    [
        (64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)),
        (512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)),
        (512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)),
        (-1024.75, "decimal", "(10,0)", D("-1025"), pl.Decimal(10, 0)),
        (-1024.75, "numeric", "(10)", D("-1025"), pl.Decimal(10, 0)),
        (-1024.75, "dec", "", D("-1024.75"), pl.Decimal(38, 9)),
    ],
)
def test_numeric_decimal_type(
    value: float,
    sqltype: str,
    prec_scale: str,
    expected_value: D,
    expected_dtype: PolarsDataType,
) -> None:
    """Check ``::numeric`` / ``::decimal`` / ``::dec`` casts in SQL.

    Parameters
    ----------
    value
        Float placed in the single-row input column ``n``.
    sqltype
        SQL type keyword used on the right-hand side of the ``::`` cast.
    prec_scale
        Text appended directly after the type keyword: ``"(p,s)"``, a
        precision-only ``"(p)"``, or an empty string.
    expected_value
        Decimal value the cast is expected to produce.
    expected_dtype
        Polars dtype the result column is expected to have. Per the cases
        above, a precision-only spec gives scale 0 and a bare ``dec`` gives
        ``Decimal(38, 9)``.
    """
    df = pl.DataFrame({"n": [value]})
    with pl.SQLContext(df=df) as ctx:
        result = ctx.execute(
            f"""
            SELECT n::{sqltype}{prec_scale} AS "dec" FROM df
            """
        )
    # The context is not eager, so the result is compared to a LazyFrame.
    expected = pl.LazyFrame(
        data={"dec": [expected_value]},
        schema={"dec": expected_dtype},
    )
    assert_frame_equal(result, expected)
103
104
105
@pytest.mark.parametrize(
    ("decimals", "expected"),
    [
        (0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]),
        (1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]),
        (2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]),
        (3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]),
        (4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]),
    ],
)
def test_round_ndigits(decimals: int, expected: list[float]) -> None:
    """Check SQL ``ROUND`` with an explicit number of decimal places.

    Parameters
    ----------
    decimals
        Number of decimal places passed as the second ``ROUND`` argument.
    expected
        Values expected after rounding the fixed input column ``n``.
    """
    df = pl.DataFrame(
        {"n": [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]},
    )
    expected_series = pl.Series("n", values=expected)

    with pl.SQLContext(df=df, eager=True) as ctx:
        if decimals == 0:
            # The one-argument form must agree with an explicit 0 decimals.
            out = ctx.execute("SELECT ROUND(n) AS n FROM df")
            assert_series_equal(out["n"], expected_series)

        # The two-argument form is checked for every parametrised case.
        out = ctx.execute(f'SELECT ROUND("n",{decimals}) AS n FROM df')
        assert_series_equal(out["n"], expected_series)
126
127
128
def test_round_ndigits_errors() -> None:
    """Check the errors raised for invalid SQL ``ROUND`` calls.

    Three cases are covered: a non-numeric decimals argument
    (``SQLSyntaxError``), a negative decimals argument
    (``SQLInterfaceError``), and too many arguments (``SQLSyntaxError``).
    """
    df = pl.DataFrame({"n": [99.999]})
    with pl.SQLContext(df=df, eager=True) as ctx:
        # Decimals given as a string that is not a number.
        with pytest.raises(
            SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)"
        ):
            ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")

        # Negative decimals value.
        with pytest.raises(
            SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)"
        ):
            ctx.execute("SELECT ROUND(n,-1) AS n FROM df")

        # More arguments than ROUND accepts.
        with pytest.raises(
            SQLSyntaxError, match=r"ROUND expects 1-2 arguments \(found 4\)"
        ):
            ctx.execute("SELECT ROUND(1.2345,6,7,8) AS n FROM df")
145
146
147
def test_stddev_variance() -> None:
    """Check the SQL standard-deviation and variance aggregates and aliases.

    Each of the four input columns is aggregated once with a
    standard-deviation alias and once with a variance alias. The expected
    values match the sample (n - 1) statistics, and the null in ``v3`` is
    left out of the calculation (``[-10, 10]`` gives a variance of 200).
    """
    df = pl.DataFrame(
        {
            "v1": [-1.0, 0.0, 1.0],
            "v2": [5.5, 0.0, 3.0],
            "v3": [-10, None, 10],
            "v4": [-100.0, 0.0, -50.0],
        }
    )
    with pl.SQLContext(df=df) as ctx:
        # note: we support all common aliases for std/var
        out = ctx.execute(
            """
            SELECT
              STDEV(v1) AS "v1_std",
              STDDEV(v2) AS "v2_std",
              STDEV_SAMP(v3) AS "v3_std",
              STDDEV_SAMP(v4) AS "v4_std",
              VAR(v1) AS "v1_var",
              VARIANCE(v2) AS "v2_var",
              VARIANCE(v3) AS "v3_var",
              VAR_SAMP(v4) AS "v4_var"
            FROM df
            """
        ).collect()

        assert_frame_equal(
            out,
            pl.DataFrame(
                {
                    "v1_std": [1.0],
                    "v2_std": [2.7537852736431],
                    "v3_std": [14.142135623731],
                    "v4_std": [50.0],
                    "v1_var": [1.0],
                    "v2_var": [7.5833333333333],
                    "v3_var": [200.0],
                    "v4_var": [2500.0],
                }
            ),
        )
188
189