Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/conftest.py
6939 views
1
from __future__ import annotations
2
3
import io
4
from pathlib import PosixPath
5
from typing import Any, Callable, TypeVar, cast
6
7
import pytest
8
9
import polars as pl
10
from polars._typing import PartitioningScheme
11
12
13
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the ``--cloud-distributed`` command-line flag.

    The flag is a boolean switch (off by default). The ``_patched_cloud``
    autouse fixture reads it to decide whether to reroute queries through
    the Polars Cloud distributed engine.
    """
    parser.addoption(
        "--cloud-distributed",
        action="store_true",
        default=False,
        help="Run all queries on the distributed engine by default",
    )
20
21
22
@pytest.fixture(autouse=True)
def _patched_cloud(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Route polars queries through the Polars Cloud distributed engine.

    Does nothing unless pytest was started with ``--cloud-distributed``. When
    that flag is set, this autouse fixture:

    * sets the compute context to an insecure cluster on ``localhost``;
    * patches ``pl.scan_<ext>`` / ``pl.read_<ext>`` (parquet, csv, ipc, ndjson)
      so in-memory sources (file objects, ``bytes``) are first written to
      temporary files under ``/tmp``;
    * patches ``pl.LazyFrame.sink_<ext>`` so sinks execute remotely, copying
      the output back into file-like targets;
    * patches ``pl.LazyFrame.collect`` to execute remotely and collect the
      result locally;
    * sets ``POLARS_SKIP_CLIENT_CHECK=1``.

    Attribute and environment patches go through ``monkeypatch`` and are
    undone at teardown; the compute context is not reset by this fixture.
    """
    if not request.config.getoption("--cloud-distributed"):
        return

    import signal
    import uuid
    from pathlib import Path

    from polars_cloud import ClusterContext, DirectQuery, set_compute_context

    # Upper bound, in seconds, for a single remote execution.
    TIMEOUT_SECS = 20

    T = TypeVar("T")

    def with_timeout(f: Callable[[], T]) -> T:
        """Return ``f()``, raising ``TimeoutError`` after ``TIMEOUT_SECS`` seconds.

        Implemented with ``SIGALRM``. The alarm is always cancelled and the
        previous handler restored on exit, so a still-pending alarm can never
        fire later inside an unrelated test.
        """

        def handler(signum: Any, frame: Any) -> None:
            msg = "test timed out"
            raise TimeoutError(msg)

        prev_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(TIMEOUT_SECS)
        try:
            return f()
        finally:
            signal.alarm(0)
            # ``signal.signal`` returns None when the previous handler was not
            # installed from Python; fall back to the default disposition then.
            signal.signal(
                signal.SIGALRM,
                prev_handler if prev_handler is not None else signal.SIG_DFL,
            )

    ctx = ClusterContext("localhost", insecure=True)
    set_compute_context(ctx)

    prev_collect = pl.LazyFrame.collect

    def cloud_collect(lf: pl.LazyFrame, *args: Any, **kwargs: Any) -> pl.DataFrame:
        """Replacement for ``pl.LazyFrame.collect`` that executes ``lf`` remotely.

        The query runs on the distributed engine (bounded by ``with_timeout``)
        and the resulting frame is collected locally with the original
        ``collect``. All extra positional and keyword arguments (including
        ``optimizations`` and ``engine``) are ignored.
        """
        # issue: cloud client should use pl.QueryOptFlags()
        result = with_timeout(
            lambda: lf.remote(plan_type="plain")
            .distributed()
            .execute()
            .await_result()
        )
        return prev_collect(result.lazy())

    class LazyExe:
        """Deferred remote sink that is run by calling ``collect``.

        When the caller sank into a file-like target, the remote query writes
        to the temporary file ``path`` instead. ``collect`` copies that file
        into ``prev_tgt`` and deletes it.
        """

        def __init__(
            self, query: DirectQuery, prev_tgt: Any | None, path: str | Path
        ) -> None:
            self.query = query
            # Original file-like sink target, or None when sinking to a path.
            self.prev_tgt = prev_tgt
            self.path = path

        def collect(self) -> pl.DataFrame:
            # 1. Actually execute the query.
            with_timeout(self.query.await_result)

            # 2. If our target was different, write the result into our
            # target transparently.
            if self.prev_tgt is not None:
                tmp = Path(self.path)
                try:
                    if isinstance(self.prev_tgt, io.TextIOBase):
                        # Decode as UTF-8, matching the encoding io_to_path uses.
                        self.prev_tgt.write(tmp.read_text(encoding="utf-8"))
                    else:
                        self.prev_tgt.write(tmp.read_bytes())
                finally:
                    # Delete the temporary file even if copying failed.
                    tmp.unlink(missing_ok=True)

            # Sinks always return an empty DataFrame.
            return pl.DataFrame({})

    def io_to_path(s: io.IOBase, ext: str) -> Path:
        """Write the rest of stream ``s`` to a new ``/tmp/pc-<uuid>.<ext>`` file.

        Text read from ``s`` is encoded as UTF-8. The stream is left positioned
        at its end. Returns the path of the new file.
        """
        path = Path(f"/tmp/pc-{uuid.uuid4()!s}.{ext}")
        data = s.read()
        if isinstance(data, str):
            data = data.encode("utf-8")
        path.write_bytes(data)
        s.seek(0, io.SEEK_END)
        return path

    def to_scan_source(item: Any, ext: str) -> Any:
        """Spill an in-memory source (stream or ``bytes``) to a file; pass others through."""
        if isinstance(item, io.IOBase):
            return io_to_path(item, ext)
        if isinstance(item, bytes):
            return io_to_path(io.BytesIO(item), ext)
        return item

    def prepare_scan_sources(src: Any, ext: str) -> str | Path | list[str | Path]:
        """Convert ``src`` (or each element of a list ``src``) into a file path.

        ``ext`` is the file-format extension used for spilled temporary files.
        A new list is built, so the caller's list is never mutated.
        """
        if isinstance(src, list):
            src = [to_scan_source(item, ext) for item in src]
        else:
            src = to_scan_source(src, ext)

        assert isinstance(src, (str, Path)) or (
            isinstance(src, list) and all(isinstance(x, (str, Path)) for x in src)
        )

        return src

    def create_cloud_scan(ext: str) -> Callable[..., pl.LazyFrame]:
        """Wrap ``pl.scan_<ext>`` so in-memory sources are spilled to files first."""
        prev_scan = cast("Callable[..., pl.LazyFrame]", getattr(pl, f"scan_{ext}"))

        def _(
            source: io.BytesIO | io.StringIO | str | Path, *args: Any, **kwargs: Any
        ) -> pl.LazyFrame:
            prepared = prepare_scan_sources(source, ext)
            return prev_scan(prepared, *args, **kwargs)  # type: ignore[no-any-return]

        return _

    def create_read(ext: str) -> Callable[..., pl.DataFrame]:
        """Wrap ``pl.read_<ext>`` so in-memory sources are spilled to files first."""
        prev_read = cast("Callable[..., pl.DataFrame]", getattr(pl, f"read_{ext}"))

        def _(
            source: io.BytesIO | str | Path, *args: Any, **kwargs: Any
        ) -> pl.DataFrame:
            # With use_pyarrow=True the parquet source is passed through untouched.
            if ext == "parquet" and kwargs.get("use_pyarrow", False):
                return prev_read(source, *args, **kwargs)  # type: ignore[no-any-return]

            prepared = prepare_scan_sources(source, ext)
            return prev_read(prepared, *args, **kwargs)  # type: ignore[no-any-return]

        return _

    class SimpleLazyExe:
        """Lazy-sink result whose ``collect`` runs the original (local) collect."""

        def __init__(self, query: pl.LazyFrame) -> None:
            self._ldf = query._ldf
            self.query = query

        def collect(self, *args: Any, **kwargs: Any) -> pl.DataFrame:
            return prev_collect(self.query, *args, **kwargs)  # type: ignore[no-any-return]

    def create_cloud_sink(
        ext: str, unsupported: list[str]
    ) -> Callable[..., pl.LazyFrame | None]:
        """Wrap ``pl.LazyFrame.sink_<ext>`` so sinks run on the distributed engine.

        ``unsupported`` lists keyword arguments dropped before calling the
        remote sink.
        """
        prev_sink = cast(
            "Callable[..., pl.LazyFrame | None]", getattr(pl.LazyFrame, f"sink_{ext}")
        )

        def _(lf: pl.LazyFrame, *args: Any, **kwargs: Any) -> pl.LazyFrame | None:
            # Accept the target passed as ``path=`` as well as positionally.
            if not args and "path" in kwargs:
                args = (kwargs.pop("path"),)
            target = args[0]

            # The cloud client sinks to a "placeholder-path"; those calls and
            # partitioned sinks go to the original polars sink instead of
            # being sent to the cluster.
            if target == "placeholder-path" or isinstance(target, PartitioningScheme):
                prev_lazy = kwargs.get("lazy", False)
                kwargs["lazy"] = True
                slf = SimpleLazyExe(prev_sink(lf, *args, **kwargs))  # type: ignore[arg-type]
                if prev_lazy:
                    return slf  # type: ignore[return-value]

                slf.collect(
                    optimizations=kwargs.get("optimizations", pl.QueryOptFlags()),
                )
                return None

            prev_tgt = None
            if callable(getattr(target, "write", None)):
                # File-like target: sink remotely to a temp file and copy it
                # into the target when the query is collected.
                prev_tgt = target
                target = f"/tmp/pc-{uuid.uuid4()!s}.{ext}"
            elif isinstance(target, PosixPath):
                target = str(target)
            args = (target, *args[1:])

            lazy = kwargs.pop("lazy", False)

            # these are all the unsupported flags
            for flag in unsupported:
                kwargs.pop(flag, None)

            # NOTE(review): passed as the string "True", not a bool -- confirm
            # this is what the cloud sink expects.
            kwargs["sink_to_single_file"] = "True"

            sink = getattr(lf.remote(plan_type="plain").distributed(), f"sink_{ext}")
            q = sink(*args, **kwargs)
            assert isinstance(q, DirectQuery)
            query = LazyExe(q, prev_tgt, target)

            if lazy:
                return query  # type: ignore[return-value]

            # If the sink is not lazy, we are expected to collect it.
            query.collect()
            return None

        return _

    # fix: these need to become supported somehow
    BASE_UNSUPPORTED = ["engine", "optimizations", "mkdir", "retries"]
    # Each factory receives the format explicitly; no nested helper reads this
    # loop variable, so no closure can late-bind to its final value.
    for fmt in ["parquet", "csv", "ipc", "ndjson"]:
        monkeypatch.setattr(f"polars.scan_{fmt}", create_cloud_scan(fmt))
        monkeypatch.setattr(f"polars.read_{fmt}", create_read(fmt))
        monkeypatch.setattr(
            f"polars.LazyFrame.sink_{fmt}",
            create_cloud_sink(fmt, BASE_UNSUPPORTED),
        )

    monkeypatch.setattr("polars.LazyFrame.collect", cloud_collect)
    monkeypatch.setenv("POLARS_SKIP_CLIENT_CHECK", "1")
236
237