Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_io_plugin.py
6939 views
1
from __future__ import annotations
2
3
import datetime
4
import io
5
import subprocess
6
import sys
7
from typing import TYPE_CHECKING
8
9
import numpy as np
10
import pytest
11
12
import polars as pl
13
from polars.io.plugins import register_io_source
14
from polars.testing import assert_frame_equal, assert_series_equal
15
16
if TYPE_CHECKING:
17
from collections.abc import Iterator
18
19
20
def test_io_plugin_predicate_no_serialization_21130() -> None:
21
def custom_io() -> pl.LazyFrame:
22
def source_generator(
23
with_columns: list[str] | None,
24
predicate: pl.Expr | None,
25
n_rows: int | None,
26
batch_size: int | None,
27
) -> Iterator[pl.DataFrame]:
28
df = pl.DataFrame(
29
{"json_val": ['{"a":"1"}', None, '{"a":2}', '{"a":2.1}', '{"a":true}']}
30
)
31
if predicate is not None:
32
df = df.filter(predicate)
33
if batch_size and df.height > batch_size:
34
yield from df.iter_slices(n_rows=batch_size)
35
else:
36
yield df
37
38
return register_io_source(
39
io_source=source_generator, schema={"json_val": pl.String}
40
)
41
42
lf = custom_io()
43
assert lf.filter(
44
pl.col("json_val").str.json_path_match("$.a").is_in(["1"])
45
).collect().to_dict(as_series=False) == {"json_val": ['{"a":"1"}']}
46
47
48
def test_defer_validate_true() -> None:
49
lf = pl.defer(
50
lambda: pl.DataFrame({"a": np.ones(3)}),
51
schema={"a": pl.Boolean},
52
validate_schema=True,
53
)
54
with pytest.raises(pl.exceptions.SchemaError):
55
lf.collect()
56
57
58
@pytest.mark.may_fail_cloud
@pytest.mark.may_fail_auto_streaming  # IO plugin validate=False schema mismatch
def test_defer_validate_false() -> None:
    """With validate_schema=False the mismatching frame passes through as-is."""
    lazy = pl.defer(
        lambda: pl.DataFrame({"a": np.ones(3)}),
        schema={"a": pl.Boolean},
        validate_schema=False,
    )
    result = lazy.collect().to_dict(as_series=False)
    assert result == {"a": [1.0, 1.0, 1.0]}
67
68
69
def test_empty_iterator_io_plugin() -> None:
70
def _io_source(
71
with_columns: list[str] | None,
72
predicate: pl.Expr | None,
73
n_rows: int | None,
74
batch_size: int | None,
75
) -> Iterator[pl.DataFrame]:
76
yield from []
77
78
schema = pl.Schema([("a", pl.Int64)])
79
df = register_io_source(_io_source, schema=schema)
80
assert df.collect().schema == schema
81
82
83
def test_scan_lines() -> None:
84
def scan_lines(f: io.BytesIO) -> pl.LazyFrame:
85
schema = pl.Schema({"lines": pl.String()})
86
87
def generator(
88
with_columns: list[str] | None,
89
predicate: pl.Expr | None,
90
n_rows: int | None,
91
batch_size: int | None,
92
) -> Iterator[pl.DataFrame]:
93
x = f
94
if batch_size is None:
95
batch_size = 100_000
96
97
batch_lines: list[str] = []
98
while n_rows != 0:
99
batch_lines.clear()
100
remaining_rows = batch_size
101
if n_rows is not None:
102
remaining_rows = min(remaining_rows, n_rows)
103
n_rows -= remaining_rows
104
105
while remaining_rows != 0 and (line := x.readline().rstrip()):
106
if isinstance(line, str):
107
batch_lines += [batch_lines]
108
else:
109
batch_lines += [line.decode()]
110
remaining_rows -= 1
111
112
df = pl.Series("lines", batch_lines, pl.String()).to_frame()
113
114
if with_columns is not None:
115
df = df.select(with_columns)
116
if predicate is not None:
117
df = df.filter(predicate)
118
119
yield df
120
121
if remaining_rows != 0:
122
break
123
124
return register_io_source(io_source=generator, schema=schema)
125
126
text = """
127
Hello
128
This is some text
129
It is spread over multiple lines
130
This allows it to read into multiple rows.
131
""".strip()
132
f = io.BytesIO(bytes(text, encoding="utf-8"))
133
134
assert_series_equal(
135
scan_lines(f).collect().to_series(),
136
pl.Series("lines", text.splitlines(), pl.String()),
137
)
138
139
140
@pytest.mark.may_fail_cloud
@pytest.mark.may_fail_auto_streaming  # IO plugin validate=False schema mismatch
def test_datetime_io_predicate_pushdown_21790() -> None:
    """A datetime comparison must be pushed into the source as a usable expr."""
    recorded: dict[str, pl.Expr | None] = {"predicate": None}
    df = pl.DataFrame(
        {
            "timestamp": [
                datetime.datetime(2024, 1, 1, 0),
                datetime.datetime(2024, 1, 3, 0),
            ]
        }
    )

    def _source(
        with_columns: list[str] | None,
        predicate: pl.Expr | None,
        n_rows: int | None,
        batch_size: int | None,
    ) -> Iterator[pl.DataFrame]:
        # Remember the pushed-down predicate so it can be inspected afterwards.
        recorded["predicate"] = predicate
        batch = df.clone()
        if with_columns is not None:
            batch = batch.select(with_columns)
        if predicate is not None:
            batch = batch.filter(predicate)
        yield batch

    schema = {"timestamp": pl.Datetime(time_unit="ns")}
    lf = register_io_source(io_source=_source, schema=schema)

    cutoff = datetime.datetime(2024, 1, 4)
    expr = pl.col("timestamp") < cutoff
    filtered_df = lf.filter(expr).collect()

    pushed_predicate = recorded["predicate"]
    assert pushed_predicate is not None
    assert_series_equal(filtered_df.to_series(), df.filter(expr).to_series())

    # Inspect the pushed expression's children: the first evaluates to the
    # cutoff literal, the second wraps the column reference.
    dt_val, column_cast = pushed_predicate.meta.pop()
    assert pl.DataFrame({}).select(dt_val).item() == cutoff

    column = column_cast.meta.pop()[0]
    assert column.meta == pl.col("timestamp")
187
188
189
@pytest.mark.parametrize("validate", [True, False])
def test_reordered_columns_22731(validate: bool) -> None:
    """Selecting columns in a different order than the schema must work."""

    def my_scan() -> pl.LazyFrame:
        schema = pl.Schema({"a": pl.Int64, "b": pl.Int64})

        def source_generator(
            with_columns: list[str] | None,
            predicate: pl.Expr | None,
            n_rows: int | None,
            batch_size: int | None,
        ) -> Iterator[pl.DataFrame]:
            remaining = pl.DataFrame({"a": [1, 2, 3], "b": [42, 13, 37]})

            if n_rows is not None:
                remaining = remaining.head(min(n_rows, remaining.height))

            # Default to one-row batches unless a batch size was requested.
            chunk_rows = 1 if batch_size is None else batch_size

            while remaining.height > 0:
                take = min(chunk_rows, remaining.height)
                cur = remaining.head(take)
                remaining = remaining.slice(take)

                if predicate is not None:
                    cur = cur.filter(predicate)
                if with_columns is not None:
                    cur = cur.select(with_columns)

                yield cur

        return register_io_source(
            io_source=source_generator, schema=schema, validate_schema=validate
        )

    expected_select = pl.DataFrame({"b": [42, 13, 37], "a": [1, 2, 3]})
    assert_frame_equal(my_scan().select("b", "a").collect(), expected_select)

    expected_ri = pl.DataFrame({"b": [42, 13, 37], "a": [1, 2, 3]}).with_row_index()
    assert_frame_equal(
        my_scan().select("b", "a").with_row_index().collect(),
        expected_ri,
    )

    expected_with_columns = pl.DataFrame({"a": [1, 2, 3], "b": [42, 13, 37]})
    assert_frame_equal(
        my_scan().with_columns("b", "a").collect(), expected_with_columns
    )
238
239
240
def test_io_plugin_reentrant_deadlock() -> None:
    # Run in a subprocess so POLARS_MAX_THREADS=1 takes effect before polars
    # is imported. The child collects an IO-plugin scan whose generator itself
    # collects further IO-plugin scans (re-entrancy); with a single-threaded
    # pool this must complete rather than deadlock.
    out = subprocess.check_output(
        [
            sys.executable,
            "-c",
            """\
from __future__ import annotations

import os
import sys

os.environ["POLARS_MAX_THREADS"] = "1"

import polars as pl
from polars.io.plugins import register_io_source

assert pl.thread_pool_size() == 1

n = 3
i = 0


def reentrant(
    with_columns: list[str] | None,
    predicate: pl.Expr | None,
    n_rows: int | None,
    batch_size: int | None,
):
    global i

    df = pl.DataFrame({"x": 1})

    if i < n:
        i += 1
        yield register_io_source(io_source=reentrant, schema={"x": pl.Int64}).collect()

    yield df


register_io_source(io_source=reentrant, schema={"x": pl.Int64}).collect()

print("OK", end="", file=sys.stderr)
""",
        ],
        stderr=subprocess.STDOUT,
        timeout=7,  # a deadlocked child would otherwise hang the suite
    )

    assert out == b"OK"
289
290
291
def test_io_plugin_categorical_24172() -> None:
    """A multi-chunk Categorical column must round-trip through an IO plugin."""
    schema = {"cat": pl.Categorical}

    # Concatenate without rechunking so the column keeps two chunks.
    part = pl.DataFrame({"cat": ["X", "Y"]}, schema=schema)
    df = pl.concat([part, part], rechunk=False)

    assert df.n_chunks() == 2

    result = register_io_source(lambda *_: iter([df]), schema=df.schema).collect()
    assert_frame_equal(result, df)
309