Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_scan_options.py
6939 views
1
from __future__ import annotations
2
3
import io
4
from datetime import datetime
5
from typing import IO, Any, Callable
6
from zoneinfo import ZoneInfo
7
8
import pytest
9
10
import polars as pl
11
from polars.testing import assert_frame_equal
12
13
14
@pytest.mark.parametrize(
    ("literal_values", "expected", "cast_options"),
    [
        # Integer upcast: Int32 in the second file widened to the Int64 schema.
        (
            (pl.lit(1, dtype=pl.Int64), pl.lit(2, dtype=pl.Int32)),
            pl.Series([1, 2], dtype=pl.Int64),
            pl.ScanCastOptions(integer_cast="upcast"),
        ),
        # Float upcast: Float32 -> Float64.
        (
            (pl.lit(1.0, dtype=pl.Float64), pl.lit(2.0, dtype=pl.Float32)),
            pl.Series([1, 2], dtype=pl.Float64),
            pl.ScanCastOptions(float_cast="upcast"),
        ),
        # Float downcast: Float64 -> Float32 (needs "downcast" explicitly allowed).
        (
            (pl.lit(1.0, dtype=pl.Float32), pl.lit(2.0, dtype=pl.Float64)),
            pl.Series([1, 2], dtype=pl.Float32),
            pl.ScanCastOptions(float_cast=["upcast", "downcast"]),
        ),
        # Datetime time-unit downcast: ns -> ms.
        (
            (
                pl.lit(datetime(2025, 1, 1), dtype=pl.Datetime(time_unit="ms")),
                pl.lit(datetime(2025, 1, 2), dtype=pl.Datetime(time_unit="ns")),
            ),
            pl.Series(
                [datetime(2025, 1, 1), datetime(2025, 1, 2)],
                dtype=pl.Datetime(time_unit="ms"),
            ),
            pl.ScanCastOptions(datetime_cast="nanosecond-downcast"),
        ),
        # Datetime time-unit downcast combined with timezone conversion
        # (Sydney -> Amsterdam; 2025-01-02 00:00 Sydney == 2025-01-01 14:00 Amsterdam).
        (
            (
                pl.lit(
                    datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),
                    dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),
                ),
                pl.lit(
                    datetime(2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")),
                    dtype=pl.Datetime(time_unit="ns", time_zone="Australia/Sydney"),
                ),
            ),
            pl.Series(
                [
                    datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),
                    datetime(2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")),
                ],
                dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Missing struct field inserted as nulls.
        (
            ( # We also test nested primitive upcast policy with this one
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"a": [[2]]},
                    dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int8, 1))}),
                ),
            ),
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            pl.ScanCastOptions(
                integer_cast="upcast",
                missing_struct_fields="insert",
            ),
        ),
        # Field reordering needs no cast options at all (cast_options=None).
        (
            ( # Test same set of struct fields but in different order
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"b": None, "a": [[2]]},
                    dtype=pl.Struct(
                        {"b": pl.Int32, "a": pl.List(pl.Array(pl.Int32, 1))}
                    ),
                ),
            ),
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            None,
        ),
        # Test logical (datetime) type under list
        (
            (
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        )
                    ),
                ),
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")
                            )
                        }
                    ],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        )
                    ),
                ),
            ),
            pl.Series(
                [
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                ],
                dtype=pl.List(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    )
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Same as above, but the logical type sits under a fixed-size Array.
        (
            (
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")
                            )
                        }
                    ],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
            ),
            pl.Series(
                [
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                ],
                dtype=pl.Array(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    ),
                    shape=1,
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Test outer validity
        (
            (
                pl.lit(
                    None,
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        )
                    ),
                ),
                pl.lit(
                    [None],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        )
                    ),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.List(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    )
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Outer validity, fixed-size Array variant.
        (
            (
                pl.lit(
                    None,
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
                pl.lit(
                    [None],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.Array(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    ),
                    shape=1,
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
    ],
)
def test_scan_cast_options(
    literal_values: tuple[pl.Expr, pl.Expr],
    expected: pl.Series,
    cast_options: pl.ScanCastOptions | None,
) -> None:
    """Scan two parquet files whose schemas differ but are castable.

    The two literals are written to separate parquet files; the scan schema
    comes from the first file. Without `cast_options` the mismatch raises a
    `SchemaError`; with the given `cast_options` the scan yields `expected`.
    A `cast_options` of None marks cases that need no casting at all.
    """
    expected = expected.alias("literal")
    lv1, lv2 = literal_values

    df1 = pl.select(lv1)
    df2 = pl.select(lv2)

    # `cast()` from the Python API should give the same results.
    assert_frame_equal(
        pl.concat(
            [
                df1.cast(expected.dtype),
                df2.cast(expected.dtype),
            ]
        ),
        expected.to_frame(),
    )

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    df1.write_parquet(files[0])
    df2.write_parquet(files[1])

    for f in files:
        f.seek(0)

    # Note: Schema is taken from the first file

    if cast_options is not None:
        q = pl.scan_parquet(files)

        # The error message should hint at passing cast options.
        with pytest.raises(pl.exceptions.SchemaError, match=r"hint: .*pass"):
            q.collect()

    assert_frame_equal(
        pl.scan_parquet(files, cast_options=cast_options).collect(),
        expected.to_frame(),
    )
386
387
388
def test_scan_cast_options_forbid_int_downcast() -> None:
    """`integer_cast='upcast'` must not accidentally permit casting a wider
    integer (Int32) down to the narrower scan schema type (Int8).
    """
    frames = [
        pl.select(pl.lit(1, dtype=pl.Int8)),
        pl.select(pl.lit(2, dtype=pl.Int32)),
    ]
    buffers: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    for frame, buf in zip(frames, buffers):
        frame.write_parquet(buf)

    for buf in buffers:
        buf.seek(0)

    # With default options the dtype mismatch is a schema error.
    with pytest.raises(pl.exceptions.SchemaError):
        pl.scan_parquet(buffers).collect()

    for buf in buffers:
        buf.seek(0)

    # Even with upcasting allowed, Int32 -> Int8 is a downcast and must still fail.
    with pytest.raises(pl.exceptions.SchemaError):
        pl.scan_parquet(
            buffers,
            cast_options=pl.ScanCastOptions(integer_cast="upcast"),
        ).collect()
419
420
421
def test_scan_cast_options_extra_struct_fields() -> None:
    """`extra_struct_fields='ignore'` drops struct fields that are absent
    from the scan schema (taken from the first file).
    """
    expected = pl.Series(
        [{"a": 1}, {"a": 2}], dtype=pl.Struct({"a": pl.Int32})
    ).alias("literal")

    frames = [
        pl.select(pl.lit({"a": 1}, dtype=pl.Struct({"a": pl.Int32}))),
        pl.select(
            pl.lit(
                {"a": 2, "extra_field": 1},
                dtype=pl.Struct({"a": pl.Int32, "extra_field": pl.Int32}),
            )
        ),
    ]

    buffers: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]
    for frame, buf in zip(frames, buffers):
        frame.write_parquet(buf)
    for buf in buffers:
        buf.seek(0)

    # Default behavior: the extra struct field raises, with a hint in the message.
    with pytest.raises(pl.exceptions.SchemaError, match=r"hint: specify .*or pass"):
        pl.scan_parquet(buffers).collect()

    assert_frame_equal(
        pl.scan_parquet(
            buffers,
            cast_options=pl.ScanCastOptions(extra_struct_fields="ignore"),
        ).collect(),
        expected.to_frame(),
    )
455
456
457
def test_cast_options_ignore_extra_columns() -> None:
    """`extra_columns='ignore'` drops file columns outside the given schema."""
    buffers: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    pl.DataFrame({"a": 1}).write_parquet(buffers[0])
    pl.DataFrame({"a": 2, "b": 1}).write_parquet(buffers[1])

    # By default, column "b" in the second file is a schema error.
    with pytest.raises(
        pl.exceptions.SchemaError,
        match="extra column in file outside of expected schema: b, hint: specify.* or pass",
    ):
        pl.scan_parquet(buffers, schema={"a": pl.Int64}).collect()

    out = pl.scan_parquet(
        buffers,
        schema={"a": pl.Int64},
        extra_columns="ignore",
    ).collect()

    assert_frame_equal(out, pl.DataFrame({"a": [1, 2]}))
477
478
479
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        # TODO: Fix for all other formats
        # (pl.scan_ipc, pl.DataFrame.write_ipc),
        # (pl.scan_csv, pl.DataFrame.write_csv),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_extra_columns(
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, io.BytesIO], None],
) -> None:
    """A later file with a column beyond the first file's schema errors by
    default and is accepted with `extra_columns='ignore'`.
    """
    frames = [
        pl.DataFrame({"a": 1, "b": 1}),
        pl.DataFrame({"a": 2, "b": 2, "c": 2}),
    ]
    buffers = [io.BytesIO(), io.BytesIO()]
    for frame, buf in zip(frames, buffers):
        write_func(frame, buf)

    # The scan schema comes from the first file, so "c" is an extra column.
    with pytest.raises(
        pl.exceptions.SchemaError,
        match=r"extra column in file outside of expected schema: c, hint: ",
    ):
        scan_func(buffers).collect()

    assert_frame_equal(
        scan_func(buffers, extra_columns="ignore").collect(),  # type: ignore[call-arg]
        pl.DataFrame({"a": [1, 2], "b": [1, 2]}),
    )
509
510