Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/conftest.py
8417 views
1
from __future__ import annotations
2
3
import io
4
from pathlib import PosixPath
5
from typing import TYPE_CHECKING, Any, TypeVar, cast
6
7
import pytest
8
9
import polars as pl
10
11
if TYPE_CHECKING:
12
from collections.abc import Callable, Generator
13
14
15
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the `--cloud-distributed` command-line flag with pytest."""
    # Boolean switch: absent -> False, present -> True.
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "Run all queries by default of the distributed engine",
    }
    parser.addoption("--cloud-distributed", **flag_settings)
22
23
24
@pytest.fixture(autouse=True)
def _patched_cloud(
    request: pytest.FixtureRequest, plmonkeypatch: PlMonkeyPatch
) -> None:
    """
    Reroute Polars IO and `collect` through the distributed cloud engine.

    Does nothing unless pytest was started with `--cloud-distributed`. When the
    flag is set, the fixture points the compute context at a `ClusterContext`
    on localhost and monkeypatches `polars.scan_*`, `polars.read_*`,
    `polars.LazyFrame.sink_*` (for parquet, csv, ipc and ndjson) and
    `polars.LazyFrame.collect` so that queries run remotely.

    Parameters
    ----------
    request
        Used to read the `--cloud-distributed` option.
    plmonkeypatch
        Patcher through which every attribute / env var change is applied.
    """
    if not request.config.getoption("--cloud-distributed"):
        return

    import signal
    import uuid
    from pathlib import Path

    from polars_cloud import ClusterContext, DirectQuery, set_compute_context

    TIMEOUT_SECS = 20

    T = TypeVar("T")

    def with_timeout(f: Callable[[], T]) -> T:
        """Run `f`, raising TimeoutError if it takes longer than TIMEOUT_SECS."""

        def handler(signum: Any, frame: Any) -> None:
            msg = "test timed out"
            raise TimeoutError(msg)

        prev_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(TIMEOUT_SECS)
        try:
            return f()
        finally:
            # Disarm the alarm so it cannot fire later in unrelated code, and
            # put back whatever handler was installed before this call.
            signal.alarm(0)
            # `signal.signal` returns None when the previous handler was not
            # installed from Python; that value cannot be re-installed.
            if prev_handler is not None:
                signal.signal(signal.SIGALRM, prev_handler)

    ctx = ClusterContext("localhost", insecure=True)
    set_compute_context(ctx)

    # Keep the unpatched collect: the wrappers below still need it.
    prev_collect = pl.LazyFrame.collect

    def cloud_collect(lf: pl.LazyFrame, *args: Any, **kwargs: Any) -> pl.DataFrame:
        """Execute `lf` remotely and return the materialized result."""
        # issue: cloud client should use pl.QueryOptFlags()
        kwargs.pop("optimizations", None)
        kwargs.pop("engine", None)

        # The remote result is turned back into a LazyFrame and collected with
        # the original (unpatched) collect.
        return prev_collect(
            with_timeout(
                lambda: (
                    lf.remote(plan_type="plain")
                    .distributed()
                    .execute()
                    .await_result()
                )
            ).lazy()
        )

    class LazyExe:
        """Stand-in returned by a patched lazy sink; `collect` runs the query."""

        def __init__(
            self, query: DirectQuery, prev_tgt: io.BytesIO | None, path: str | Path
        ) -> None:
            self.query = query

            # File-like object the caller originally sinked to (None when the
            # caller passed a path).
            self.prev_tgt = prev_tgt
            # Location the remote sink actually writes to.
            self.path = path

        def collect(self) -> pl.DataFrame:
            # 1. Actually execute the query.
            with_timeout(lambda: self.query.await_result())

            # 2. If our target was different, write the result into our target
            #    transparently.
            if self.prev_tgt is not None:
                # `self.path` is a plain string here, so wrap it instead of
                # calling `Path.open` unbound on a non-Path value.
                tmp_path = Path(self.path)
                is_string = isinstance(self.prev_tgt, (io.StringIO, io.TextIOBase))

                if is_string:
                    with tmp_path.open("r") as f:
                        self.prev_tgt.write(f.read())  # type: ignore[arg-type]
                else:
                    with tmp_path.open("rb") as f:
                        self.prev_tgt.write(f.read())

                # delete the temporary file
                tmp_path.unlink()

            # Sinks always return an empty DataFrame.
            return pl.DataFrame({})

    def io_to_path(s: io.IOBase, ext: str) -> Path:
        """Dump the contents of `s` into a fresh file under /tmp and return its path."""
        path = Path(f"/tmp/pc-{uuid.uuid4()!s}.{ext}")

        with path.open("wb") as f:
            bs = s.read()
            if isinstance(bs, str):
                bs = bytes(bs, encoding="utf-8")
            f.write(bs)
        s.seek(0, 2)
        return path

    def source_to_path(item: Any, ext: str) -> Any:
        """Convert an in-memory source (file object or bytes) to a temp-file path."""
        if isinstance(item, io.IOBase):
            return io_to_path(item, ext)
        if isinstance(item, bytes):
            return io_to_path(io.BytesIO(item), ext)
        return item

    def prepare_scan_sources(src: Any, ext: str) -> str | Path | list[str | Path]:
        """
        Replace every in-memory source in `src` by a path to a temp file.

        `ext` is passed explicitly: reading it from the enclosing scope would
        pick up the loop variable of the patching loop at the bottom of this
        fixture, which holds its last value by the time a test calls this.
        """
        if isinstance(src, list):
            # Build a new list so the caller's list is left untouched.
            src = [source_to_path(item, ext) for item in src]
        else:
            src = source_to_path(src, ext)

        assert isinstance(src, (str, Path)) or (
            isinstance(src, list) and all(isinstance(x, (str, Path)) for x in src)
        )

        return src

    def create_cloud_scan(ext: str) -> Callable[..., pl.LazyFrame]:
        """Wrap `pl.scan_<ext>` so in-memory sources are written to disk first."""
        prev_scan = getattr(pl, f"scan_{ext}")
        prev_scan = cast("Callable[..., pl.LazyFrame]", prev_scan)

        def patched_scan(
            source: io.BytesIO | io.StringIO | str | Path, *args: Any, **kwargs: Any
        ) -> pl.LazyFrame:
            source = prepare_scan_sources(source, ext)  # type: ignore[assignment]
            return prev_scan(source, *args, **kwargs)  # type: ignore[no-any-return]

        return patched_scan

    def create_read(ext: str) -> Callable[..., pl.DataFrame]:
        """Wrap `pl.read_<ext>` so in-memory sources are written to disk first."""
        prev_read = getattr(pl, f"read_{ext}")
        prev_read = cast("Callable[..., pl.DataFrame]", prev_read)

        def patched_read(
            source: io.BytesIO | str | Path, *args: Any, **kwargs: Any
        ) -> pl.DataFrame:
            # The pyarrow parquet path receives the source untouched.
            if ext == "parquet" and kwargs.get("use_pyarrow", False):
                return prev_read(source, *args, **kwargs)  # type: ignore[no-any-return]

            src = prepare_scan_sources(source, ext)
            return prev_read(src, *args, **kwargs)  # type: ignore[no-any-return]

        return patched_read

    def create_cloud_sink(
        ext: str, unsupported: list[str]
    ) -> Callable[..., pl.LazyFrame | None]:
        """
        Wrap `pl.LazyFrame.sink_<ext>` so the sink runs on the remote engine.

        `unsupported` lists keyword arguments that are dropped before the
        remote sink is called.
        """
        prev_sink = getattr(pl.LazyFrame, f"sink_{ext}")
        prev_sink = cast("Callable[..., pl.LazyFrame | None]", prev_sink)

        def patched_sink(
            lf: pl.LazyFrame, *args: Any, **kwargs: Any
        ) -> pl.LazyFrame | None:
            # NOTE(review): the target is read from args[0], so a target passed
            # by keyword would raise IndexError here - confirm callers.

            # The cloud client sinks to a "placeholder-path".
            if args[0] == "placeholder-path" or isinstance(args[0], pl.PartitionBy):
                prev_lazy = kwargs.get("lazy", False)
                kwargs["lazy"] = True
                lf = prev_sink(lf, *args, **kwargs)  # type: ignore[assignment]

                class SimpleLazyExe:
                    """Wrapper whose `collect` uses the original, unpatched collect."""

                    def __init__(self, query: pl.LazyFrame) -> None:
                        self._ldf = query._ldf
                        self.query = query

                    def collect(self, *args: Any, **kwargs: Any) -> pl.DataFrame:
                        return prev_collect(self.query, *args, **kwargs)  # type: ignore[no-any-return]

                slf = SimpleLazyExe(lf)
                if prev_lazy:
                    return slf  # type: ignore[return-value]

                slf.collect(
                    optimizations=kwargs.get("optimizations", pl.QueryOptFlags()),
                )
                return None

            prev_tgt = None
            if isinstance(
                args[0], (io.BytesIO, io.StringIO, io.TextIOBase)
            ) or callable(getattr(args[0], "write", None)):
                # Writable target: sink to a temp file first; LazyExe.collect
                # copies the file content into the target afterwards.
                prev_tgt = args[0]
                args = (f"/tmp/pc-{uuid.uuid4()!s}.{ext}",) + args[1:]
            elif isinstance(args[0], PosixPath):
                args = (str(args[0]),) + args[1:]

            lazy = kwargs.pop("lazy", False)

            # these are all the unsupported flags
            for u in unsupported:
                kwargs.pop(u, None)

            # NOTE(review): this is the string "True", not the boolean - kept
            # as in the original; confirm what the cloud client expects.
            kwargs["sink_to_single_file"] = "True"

            sink = getattr(
                lf.remote(plan_type="plain").distributed(), f"sink_{ext}"
            )
            q = sink(*args, **kwargs)
            assert isinstance(q, DirectQuery)
            query = LazyExe(
                q,
                prev_tgt,
                args[0],
            )

            if lazy:
                return query  # type: ignore[return-value]

            # If the sink is not lazy, we are expected to collect it.
            query.collect()
            return None

        return patched_sink

    # fix: these need to become supported somehow
    BASE_UNSUPPORTED = ["engine", "optimizations", "mkdir", "retries"]
    for ext in ["parquet", "csv", "ipc", "ndjson"]:
        plmonkeypatch.setattr(f"polars.scan_{ext}", create_cloud_scan(ext))
        plmonkeypatch.setattr(f"polars.read_{ext}", create_read(ext))
        plmonkeypatch.setattr(
            f"polars.LazyFrame.sink_{ext}",
            create_cloud_sink(ext, BASE_UNSUPPORTED),
        )

    plmonkeypatch.setattr("polars.LazyFrame.collect", cloud_collect)
    plmonkeypatch.setenv("POLARS_SKIP_CLIENT_CHECK", "1")
238
239
240
class PlMonkeyPatch(pytest.MonkeyPatch):  # type: ignore[misc]
    """
    A wrapper of pytest.MonkeyPatch that updates Polars when an env var changes.

    Setting or deleting a variable whose name starts with `POLARS_`, and
    undoing the patches, each trigger `pl.Config.reload_env_vars()`.
    """

    def setenv(self, name: str, value: str, prepend: str | None = None) -> None:
        """Set `name` to `value`; reload the Polars config for `POLARS_*` names."""
        super().setenv(name, value, prepend)
        if name.startswith("POLARS_"):
            pl.Config.reload_env_vars()

    def delenv(self, name: str, raising: bool = True) -> None:
        """Delete `name`; reload the Polars config for `POLARS_*` names."""
        # Without this override a deleted POLARS_* variable would only be
        # noticed at the next setenv() or undo().
        super().delenv(name, raising)
        if name.startswith("POLARS_"):
            pl.Config.reload_env_vars()

    def undo(self) -> None:
        """Revert every patch, then reload the Polars config unconditionally."""
        super().undo()
        pl.Config.reload_env_vars()
254
255
256
@pytest.fixture
def plmonkeypatch() -> Generator[PlMonkeyPatch, None, None]:
    """Provide a `PlMonkeyPatch` for the test and undo its patches afterwards."""
    patcher = PlMonkeyPatch()
    yield patcher
    # Teardown: revert everything the test patched.
    patcher.undo()
262
263