Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/conftest.py
6939 views
1
from __future__ import annotations
2
3
import gc
4
import os
5
import random
6
import string
7
import sys
8
from contextlib import contextmanager
9
from typing import TYPE_CHECKING, Any, Callable, cast
10
11
import numpy as np
12
import pytest
13
14
import polars as pl
15
from polars.testing.parametric import load_profile
16
17
if TYPE_CHECKING:
18
from collections.abc import Generator
19
from types import ModuleType
20
from typing import Any
21
22
FixtureRequest = Any
23
24
# Pick the profile used by the parametric tests. It defaults to "fast" and can
# be overridden through the POLARS_HYPOTHESIS_PROFILE environment variable.
load_profile(
    profile=os.getenv("POLARS_HYPOTHESIS_PROFILE", "fast"),  # type: ignore[arg-type]
)
27
28
# Data type groups
# Each numeric/temporal group holds dtype *instances*; the composite groups are
# built from the smaller ones so they cannot drift apart.
SIGNED_INTEGER_DTYPES = [pl.Int8(), pl.Int16(), pl.Int32(), pl.Int64(), pl.Int128()]
UNSIGNED_INTEGER_DTYPES = [pl.UInt8(), pl.UInt16(), pl.UInt32(), pl.UInt64()]
INTEGER_DTYPES = [*SIGNED_INTEGER_DTYPES, *UNSIGNED_INTEGER_DTYPES]
FLOAT_DTYPES = [pl.Float32(), pl.Float64()]
NUMERIC_DTYPES = [*INTEGER_DTYPES, *FLOAT_DTYPES]

# One entry per time unit.
DATETIME_DTYPES = [pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")]
DURATION_DTYPES = [pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")]
TEMPORAL_DTYPES = [*DATETIME_DTYPES, *DURATION_DTYPES, pl.Date(), pl.Time()]

# Unlike the groups above, these are the dtype classes themselves.
NESTED_DTYPES = [pl.List, pl.Struct, pl.Array]
40
41
42
@pytest.fixture
def partition_limit() -> int:
    """Threshold at which Polars begins partitioning when built in debug mode."""
    debug_partition_threshold = 15
    return debug_partition_threshold
46
47
48
@pytest.fixture
def df() -> pl.DataFrame:
    """
    Three-row frame with one column per commonly tested dtype.

    Boolean, integer, float and string columns come in a plain and a
    null-containing variant. The temporal columns are created from integers and
    cast afterwards, and the ``strings`` column is additionally exposed as
    Categorical (``cat``) and Enum (``enum``) columns.
    """
    raw_columns: dict[str, list[Any]] = {
        "bools": [False, True, False],
        "bools_nulls": [None, True, False],
        "int": [1, 2, 3],
        "int_nulls": [1, None, 3],
        "floats": [1.0, 2.0, 3.0],
        "floats_nulls": [1.0, None, 3.0],
        "strings": ["foo", "bar", "ham"],
        "strings_nulls": ["foo", None, "ham"],
        "date": [1324, 123, 1234],
        "datetime": [13241324, 12341256, 12341234],
        "time": [13241324, 12341256, 12341234],
        "list_str": [["a", "b", None], ["a"], []],
        "list_bool": [[True, False, None], [None], []],
        "list_int": [[1, None, 3], [None], []],
        "list_flt": [[1.0, None, 3.0], [None], []],
    }
    frame = pl.DataFrame(raw_columns)

    # Same-named expressions replace their column in place; the aliased ones
    # ("cat", "enum") are appended after the existing columns.
    casts = [
        pl.col("date").cast(pl.Date),
        pl.col("datetime").cast(pl.Datetime),
        pl.col("strings").cast(pl.Categorical).alias("cat"),
        pl.col("strings").cast(pl.Enum(["foo", "ham", "bar"])).alias("enum"),
        pl.col("time").cast(pl.Time),
    ]
    return frame.with_columns(casts)
76
77
78
@pytest.fixture
def df_no_lists(df: pl.DataFrame) -> pl.DataFrame:
    """Return the ``df`` fixture with its four ``list_*`` columns removed."""
    # Each list column is named exactly once (the previous version repeated
    # "list_int").
    list_columns = ["list_str", "list_bool", "list_int", "list_flt"]
    return df.select(pl.all().exclude(list_columns))
83
84
85
@pytest.fixture
def fruits_cars() -> pl.DataFrame:
    """
    Five-row frame with two integer columns and two string columns.

    ``A`` counts up while ``B`` counts down; both are pinned to Int64.
    """
    data: dict[str, list[Any]] = {
        "A": [1, 2, 3, 4, 5],
        "fruits": ["banana", "banana", "apple", "apple", "banana"],
        "B": [5, 4, 3, 2, 1],
        "cars": ["beetle", "audi", "beetle", "beetle", "beetle"],
    }
    int_overrides = {"A": pl.Int64, "B": pl.Int64}
    return pl.DataFrame(data, schema_overrides=int_overrides)
96
97
98
@pytest.fixture
def str_ints_df() -> pl.DataFrame:
    """
    Frame of 1000 random floats (``vals``) and random one-letter strings (``strs``).

    In ``strs`` every "a" is replaced by the empty string and every "b" by
    null, so the column also contains empty and missing values.
    """
    n_rows = 1000

    letters = pl.Series("strs", random.choices(string.ascii_lowercase, k=n_rows))
    with_empty_and_null = (
        pl.when(letters == "a")
        .then(pl.lit(""))
        .when(letters == "b")
        .then(None)
        .otherwise(letters)
        .alias("strs")
    )
    strs = pl.select(with_empty_and_null).to_series()

    vals = pl.Series("vals", np.random.rand(n_rows))

    return pl.DataFrame([vals, strs])
115
116
117
# Naive datetime format strings: every combination of date/time separator
# ("T" or space), time-of-day layout and date separator ("/" or "-").
ISO8601_FORMATS_DATETIME = [
    f"%Y{date_sep}%m{date_sep}%d{time_part}"
    for dt_sep in ("T", " ")
    for time_part in (
        f"{dt_sep}%H:%M:%S",
        f"{dt_sep}%H%M%S",
        f"{dt_sep}%H:%M",
        f"{dt_sep}%H%M",
        *(f"{dt_sep}%H:%M:%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")),
        *(f"{dt_sep}%H%M%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")),
        # Date-only variant; emitted once per `dt_sep`, so it occurs twice.
        "",
    )
    for date_sep in ("/", "-")
]
134
135
136
@pytest.fixture(params=ISO8601_FORMATS_DATETIME)
def iso8601_format_datetime(request: pytest.FixtureRequest) -> str:
    """
    Provide each entry of ``ISO8601_FORMATS_DATETIME`` in turn.

    A test requesting this fixture runs once per format. Every parameter is a
    single format string, hence the ``str`` return type.
    """
    return cast(str, request.param)
139
140
141
# Time-zone-aware datetime format strings: the same combinations as the naive
# list minus the date-only variant, each ending in the "%#z" offset directive.
ISO8601_TZ_AWARE_FORMATS_DATETIME = [
    f"%Y{date_sep}%m{date_sep}%d{time_part}%#z"
    for dt_sep in ("T", " ")
    for time_part in (
        f"{dt_sep}%H:%M:%S",
        f"{dt_sep}%H%M%S",
        f"{dt_sep}%H:%M",
        f"{dt_sep}%H%M",
        *(f"{dt_sep}%H:%M:%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")),
        *(f"{dt_sep}%H%M%S.{fraction}" for fraction in ("%9f", "%6f", "%3f")),
    )
    for date_sep in ("/", "-")
]
157
158
159
@pytest.fixture(params=ISO8601_TZ_AWARE_FORMATS_DATETIME)
def iso8601_tz_aware_format_datetime(request: pytest.FixtureRequest) -> str:
    """
    Provide each entry of ``ISO8601_TZ_AWARE_FORMATS_DATETIME`` in turn.

    A test requesting this fixture runs once per format. Every parameter is a
    single format string, hence the ``str`` return type.
    """
    return cast(str, request.param)
162
163
164
# Date-only format strings, one per supported date separator.
ISO8601_FORMATS_DATE = [f"%Y{date_sep}%m{date_sep}%d" for date_sep in ("/", "-")]
169
170
171
@pytest.fixture(params=ISO8601_FORMATS_DATE)
def iso8601_format_date(request: pytest.FixtureRequest) -> str:
    """
    Provide each entry of ``ISO8601_FORMATS_DATE`` in turn.

    A test requesting this fixture runs once per format. Every parameter is a
    single format string, hence the ``str`` return type.
    """
    return cast(str, request.param)
174
175
176
class MemoryUsage:
    """
    Expose an API for measuring peak memory usage.

    The ``tracemalloc``-based implementation is commented out at the moment, so
    resetting does nothing and both getters report zero.

    Memory from PyArrow is not tracked at the moment.
    """

    def reset_tracking(self) -> None:
        """Reset tracking to zero (a no-op while tracking is disabled)."""
        # Disabled implementation:
        # gc.collect()
        # tracemalloc.stop()
        # tracemalloc.start()
        # assert self.get_peak() < 100_000
        return None

    def get_current(self) -> int:
        """
        Return currently allocated memory, in bytes.

        Intended to cover allocations made since this object was created or
        since the last ``reset_tracking()`` call, whichever is later. Always
        zero while tracking is disabled.
        """
        # Disabled implementation: tracemalloc.get_traced_memory()[0]
        current_bytes = 0
        return current_bytes

    def get_peak(self) -> int:
        """
        Return peak allocated memory, in bytes.

        Intended to cover the peak reached since this object was created or
        since the last ``reset_tracking()`` call, whichever is later. Always
        zero while tracking is disabled.
        """
        # Disabled implementation: tracemalloc.get_traced_memory()[1]
        peak_bytes = 0
        return peak_bytes
209
210
211
# The unusual `params=` form is taken from
# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - it
# marks every test that uses this fixture as slow. The mark exists because of a
# sleep that works around a CPython bug; that sleep belongs to the
# tracemalloc-based implementation kept (commented out) at the end of the
# function.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
    """
    Yield a ``MemoryUsage`` object for measuring peak memory usage.

    The requesting test is skipped on non-debug builds, when
    ``POLARS_FORCE_ASYNC=1`` is set, and on Windows.

    Not thread-safe: there should only be one instance of MemoryUsage at any
    given time.

    Memory usage from PyArrow is not tracked.
    """
    is_debug_build = pl.polars._debug  # type: ignore[attr-defined]
    if not is_debug_build:
        pytest.skip("Memory usage only available in debug/dev builds.")

    force_async = os.environ.get("POLARS_FORCE_ASYNC", "0")
    if force_async == "1":
        pytest.skip("Hangs when combined with async glob")

    if sys.platform == "win32":
        # abi3 wheels don't have the tracemalloc C APIs, which breaks linking
        # on Windows.
        pytest.skip("Windows not supported at the moment.")

    # Collect garbage both before handing out the tracker and after the test.
    gc.collect()
    try:
        yield MemoryUsage()
    finally:
        gc.collect()

    # Disabled tracemalloc-based implementation, kept for reference:
    #
    # gc.collect()
    # tracemalloc.start()
    # try:
    #     yield MemoryUsage()
    # finally:
    #     # Workaround for https://github.com/python/cpython/issues/128679
    #     time.sleep(1)
    #     gc.collect()
    #
    #     tracemalloc.stop()
251
252
253
@contextmanager
def mock_module_import(
    name: str, module: ModuleType, *, replace_if_exists: bool = False
) -> Generator[None, None, None]:
    """
    Mock an optional module import for the duration of a context.

    On exit, ``sys.modules`` is restored: a replaced module is put back, and an
    entry that did not exist before is removed again.

    Parameters
    ----------
    name
        The name of the module to mock.
    module
        A ModuleType instance representing the mocked module.
    replace_if_exists
        Whether to replace the module if it already exists in `sys.modules` (defaults to
        False, meaning that if the module is already imported, it will not be replaced).
    """
    original = sys.modules.get(name)
    if original is not None and not replace_if_exists:
        # Already imported and no override requested: leave `sys.modules` alone.
        yield
        return

    sys.modules[name] = module
    try:
        yield
    finally:
        if original is not None:
            sys.modules[name] = original
        else:
            # `pop` instead of `del`: the code under test may already have
            # removed the entry, and cleanup must not raise KeyError for that.
            sys.modules.pop(name, None)
281
282
283
def time_func(func: Callable[[], Any], *, iterations: int = 3) -> float:
    """
    Return the fastest wall-clock time of `func` over `iterations` calls.

    Parameters
    ----------
    func
        Zero-argument callable to time; its return value is discarded.
    iterations
        Number of times `func` is called (defaults to 3). Must be at least 1.

    Returns
    -------
    float
        The shortest single-call duration in seconds, measured with
        `time.perf_counter`.

    Raises
    ------
    ValueError
        If `iterations` is smaller than 1.
    """
    from time import perf_counter

    if iterations < 1:
        msg = f"`iterations` must be at least 1, got {iterations}"
        raise ValueError(msg)

    # Keep a running minimum instead of collecting every sample.
    fastest = float("inf")
    for _ in range(iterations):
        start = perf_counter()
        func()
        fastest = min(fastest, perf_counter() - start)

    return fastest
295
296