Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/conftest.py
8416 views
1
from __future__ import annotations
2
3
import gc
4
import os
5
import random
6
import string
7
import sys
8
from contextlib import contextmanager
9
from typing import TYPE_CHECKING, Any, cast
10
11
import numpy as np
12
import pytest
13
14
import polars as pl
15
from polars.testing.parametric import load_profile
16
17
if TYPE_CHECKING:
18
from collections.abc import Callable, Generator
19
from types import ModuleType
20
from typing import Any
21
22
FixtureRequest = Any

# Pick the Hypothesis profile for parametric tests; the environment variable
# POLARS_HYPOTHESIS_PROFILE overrides the "fast" default.
load_profile(
    profile=os.getenv("POLARS_HYPOTHESIS_PROFILE", "fast"),  # type: ignore[arg-type]
)
27
28
# Data type groups
# Every group below holds dtype *instances*, except NESTED_DTYPES, which holds
# the dtype classes themselves (not instances).
SIGNED_INTEGER_DTYPES = [
    dtype() for dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Int128)
]
UNSIGNED_INTEGER_DTYPES = [
    dtype() for dtype in (pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.UInt128)
]
INTEGER_DTYPES = [*SIGNED_INTEGER_DTYPES, *UNSIGNED_INTEGER_DTYPES]
FLOAT_DTYPES = [dtype() for dtype in (pl.Float16, pl.Float32, pl.Float64)]
NUMERIC_DTYPES = [*INTEGER_DTYPES, *FLOAT_DTYPES]

DATETIME_DTYPES = [pl.Datetime(unit) for unit in ("ms", "us", "ns")]
DURATION_DTYPES = [pl.Duration(unit) for unit in ("ms", "us", "ns")]
TEMPORAL_DTYPES = [*DATETIME_DTYPES, *DURATION_DTYPES, pl.Date(), pl.Time()]

NESTED_DTYPES = [pl.List, pl.Struct, pl.Array]
46
47
48
@pytest.fixture
def partition_limit() -> int:
    """Limit at which debug builds of Polars begin partitioning."""
    debug_partition_threshold = 15
    return debug_partition_threshold
52
53
54
@pytest.fixture
def df() -> pl.DataFrame:
    """
    Three-row frame with one column per dtype exercised by the tests.

    Most columns come in a plain and a ``*_nulls`` variant. The integer
    ``date``/``datetime``/``time`` columns are cast to their temporal dtypes,
    and ``strings`` is additionally exposed as Categorical (``cat``) and
    Enum (``enum``) columns appended at the end.
    """
    raw_columns = {
        "bools": [False, True, False],
        "bools_nulls": [None, True, False],
        "int": [1, 2, 3],
        "int_nulls": [1, None, 3],
        "floats": [1.0, 2.0, 3.0],
        "floats_nulls": [1.0, None, 3.0],
        "strings": ["foo", "bar", "ham"],
        "strings_nulls": ["foo", None, "ham"],
        "date": [1324, 123, 1234],
        "datetime": [13241324, 12341256, 12341234],
        "time": [13241324, 12341256, 12341234],
        "list_str": [["a", "b", None], ["a"], []],
        "list_bool": [[True, False, None], [None], []],
        "list_int": [[1, None, 3], [None], []],
        "list_flt": [[1.0, None, 3.0], [None], []],
    }
    conversions = [
        pl.col("date").cast(pl.Date),
        pl.col("datetime").cast(pl.Datetime),
        pl.col("strings").cast(pl.Categorical).alias("cat"),
        pl.col("strings").cast(pl.Enum(["foo", "ham", "bar"])).alias("enum"),
        pl.col("time").cast(pl.Time),
    ]
    base = pl.DataFrame(raw_columns)
    return base.with_columns(*conversions)
82
83
84
@pytest.fixture
def df_no_lists(df: pl.DataFrame) -> pl.DataFrame:
    """Return the ``df`` fixture with its list-valued columns removed."""
    # Each list column of ``df`` is named exactly once.
    return df.select(
        pl.all().exclude(["list_str", "list_int", "list_bool", "list_flt"])
    )
89
90
91
@pytest.fixture
def fruits_cars() -> pl.DataFrame:
    """Five-row frame with two Int64 columns (A, B) and two string columns."""
    data = {
        "A": [1, 2, 3, 4, 5],
        "fruits": ["banana", "banana", "apple", "apple", "banana"],
        "B": [5, 4, 3, 2, 1],
        "cars": ["beetle", "audi", "beetle", "beetle", "beetle"],
    }
    int64_overrides = dict.fromkeys(("A", "B"), pl.Int64)
    return pl.DataFrame(data, schema_overrides=int64_overrides)
102
103
104
@pytest.fixture
def str_ints_df() -> pl.DataFrame:
    """
    Return a 1000-row frame with a float column and a nullable string column.

    ``strs`` holds random single lowercase letters, where every "a" becomes an
    empty string and every "b" becomes null; ``vals`` holds floats drawn
    uniformly from [0, 1).

    Both random generators are seeded locally, so the data is reproducible
    from run to run (a failing test can be re-run on the same input) and the
    global ``random`` / NumPy RNG state is left untouched.
    """
    n = 1000
    seed = 0
    py_rng = random.Random(seed)
    np_rng = np.random.default_rng(seed)

    letters = pl.Series("strs", py_rng.choices(string.ascii_lowercase, k=n))
    strs = pl.select(
        pl.when(letters == "a")
        .then(pl.lit(""))
        .when(letters == "b")
        .then(None)
        .otherwise(letters)
        .alias("strs")
    ).to_series()

    vals = pl.Series("vals", np_rng.random(n))

    return pl.DataFrame([vals, strs])
121
122
123
# strftime-style naive datetime formats: a "/"- or "-"-separated date,
# optionally followed by a "T"- or " "-separated time part (with or without
# colons, minute or second precision, optionally with %9f/%6f/%3f fractions).
# Built with a comprehension so no loop variables leak into the module.
ISO8601_FORMATS_DATETIME = list(
    # dict.fromkeys drops duplicates while keeping first-seen order: the
    # bare-date formats (empty time part) would otherwise appear once per
    # time separator. They are the final entries, so other param ids are stable.
    dict.fromkeys(
        f"%Y{date_sep}%m{date_sep}%d{hms}"
        for sep in ("T", " ")
        for hms in (
            [f"{sep}%H:%M:%S", f"{sep}%H%M%S", f"{sep}%H:%M", f"{sep}%H%M"]
            + [f"{sep}%H:%M:%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")]
            + [f"{sep}%H%M%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")]
            + [""]
        )
        for date_sep in ("/", "-")
    )
)


@pytest.fixture(params=ISO8601_FORMATS_DATETIME)
def iso8601_format_datetime(request: pytest.FixtureRequest) -> str:
    """Return one format string from ISO8601_FORMATS_DATETIME per parametrization."""
    return cast("str", request.param)
145
146
147
# strftime-style timezone-aware datetime formats: a "/"- or "-"-separated date,
# a "T"- or " "-separated time part (with or without colons, minute or second
# precision, optionally with %9f/%6f/%3f fractions) and a trailing "%#z".
# There is no date-only variant here. Built with a comprehension so no loop
# variables leak into the module.
ISO8601_TZ_AWARE_FORMATS_DATETIME = [
    f"%Y{date_sep}%m{date_sep}%d{hms}%#z"
    for sep in ("T", " ")
    for hms in (
        [f"{sep}%H:%M:%S", f"{sep}%H%M%S", f"{sep}%H:%M", f"{sep}%H%M"]
        + [f"{sep}%H:%M:%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")]
        + [f"{sep}%H%M%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")]
    )
    for date_sep in ("/", "-")
]


@pytest.fixture(params=ISO8601_TZ_AWARE_FORMATS_DATETIME)
def iso8601_tz_aware_format_datetime(request: pytest.FixtureRequest) -> str:
    """Return one format string from ISO8601_TZ_AWARE_FORMATS_DATETIME per parametrization."""
    return cast("str", request.param)
168
169
170
# strftime-style date-only formats, one per supported date separator.
ISO8601_FORMATS_DATE = [f"%Y{sep}%m{sep}%d" for sep in ("/", "-")]


@pytest.fixture(params=ISO8601_FORMATS_DATE)
def iso8601_format_date(request: pytest.FixtureRequest) -> str:
    """Return one format string from ISO8601_FORMATS_DATE per parametrization."""
    return cast("str", request.param)
180
181
182
class MemoryUsage:
    """
    API for measuring peak memory usage — currently a stub.

    The tracemalloc-based implementation is disabled, so every measurement
    method returns 0 and ``reset_tracking()`` does nothing. Assertions such
    as ``get_peak() < limit`` therefore always pass for now.

    Memory from PyArrow is not tracked.
    """

    def reset_tracking(self) -> None:
        """Reset tracking to zero (currently a no-op)."""
        # NOTE: the disabled implementation ran gc.collect(), then
        # tracemalloc.stop() / tracemalloc.start(), and asserted the fresh
        # peak was below 100_000 bytes.

    def get_current(self) -> int:
        """
        Return currently allocated memory, in bytes.

        While tracking is disabled this always returns 0. When enabled, it
        covers allocations since this object was created or
        ``reset_tracking()`` was called, whichever is later.
        """
        # Disabled implementation: tracemalloc.get_traced_memory()[0]
        return 0

    def get_peak(self) -> int:
        """
        Return peak allocated memory, in bytes.

        While tracking is disabled this always returns 0. When enabled, it
        covers allocations since this object was created or
        ``reset_tracking()`` was called, whichever is later.
        """
        # Disabled implementation: tracemalloc.get_traced_memory()[1]
        return 0
215
216
217
# The bizarre syntax is from
# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - it
# marks every test using this fixture as slow. The mark exists because the
# fixture slept for a second on teardown to work around a CPython bug
# (https://github.com/python/cpython/issues/128679). That sleep is currently
# disabled along with tracemalloc tracking, but the mark is kept.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
    """
    Provide an API for measuring peak memory usage.

    Not thread-safe: there should only be one instance of MemoryUsage at any
    given time.

    Memory usage from PyArrow is not tracked. tracemalloc-based tracking is
    currently disabled, so the yielded MemoryUsage reports 0 for every
    measurement.

    Skips the test on non-debug builds, when POLARS_FORCE_ASYNC=1, and on
    Windows.
    """
    if not pl._plr._debug:
        pytest.skip("Memory usage only available in debug/dev builds.")

    if os.getenv("POLARS_FORCE_ASYNC", "0") == "1":
        pytest.skip("Hangs when combined with async glob")

    if sys.platform == "win32":
        # abi3 wheels don't have the tracemalloc C APIs, which breaks linking
        # on Windows.
        pytest.skip("Windows not supported at the moment.")

    # Collect garbage around the test, presumably so objects left over from
    # other tests don't distort measurements.
    gc.collect()
    try:
        yield MemoryUsage()
    finally:
        # NOTE: the disabled implementation also started tracemalloc before
        # the yield, and here slept 1s (cpython#128679 workaround) and
        # stopped tracemalloc.
        gc.collect()
257
258
259
@contextmanager
def mock_module_import(
    name: str, module: ModuleType, *, replace_if_exists: bool = False
) -> Generator[None, None, None]:
    """
    Mock an optional module import for the duration of a context.

    Parameters
    ----------
    name
        The name of the module to mock.
    module
        A ModuleType instance representing the mocked module.
    replace_if_exists
        Whether to replace the module if it already exists in `sys.modules` (defaults to
        False, meaning that if the module is already imported, it will not be replaced).

    Notes
    -----
    A ``None`` entry in `sys.modules` (Python's way of blocking an import) is
    not treated as an existing module, so it is replaced by `module`. On exit
    the ``None`` entry is restored rather than deleted.
    """
    had_entry = name in sys.modules
    original = sys.modules.get(name)
    if original is not None and not replace_if_exists:
        # A real module is already loaded and we were asked to leave it alone.
        yield
        return

    sys.modules[name] = module
    try:
        yield
    finally:
        if had_entry:
            # Put back exactly what was there, including a ``None`` blocker.
            sys.modules[name] = original  # type: ignore[assignment]
        else:
            # ``pop`` rather than ``del``: code inside the context may have
            # already removed the entry itself.
            sys.modules.pop(name, None)
287
288
289
def time_func(func: Callable[[], Any], *, iterations: int = 3) -> float:
    """
    Return the minimum elapsed time of ``iterations`` calls to `func`.

    The minimum (rather than the mean) is used because it is the run least
    affected by unrelated system noise.

    Parameters
    ----------
    func
        Zero-argument callable to time; its return value is discarded.
    iterations
        Number of timed calls (default 3); must be at least 1.

    Returns
    -------
    float
        Duration of the fastest call, in seconds, measured with
        `time.perf_counter`.

    Raises
    ------
    ValueError
        If `iterations` is less than 1.
    """
    from time import perf_counter

    if iterations < 1:
        msg = f"`iterations` must be at least 1, got {iterations}"
        raise ValueError(msg)

    def _timed_call() -> float:
        # Time a single invocation of `func`.
        start = perf_counter()
        func()
        return perf_counter() - start

    return min(_timed_call() for _ in range(iterations))
301
302