Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_scan_options.py
8424 views
1
from __future__ import annotations
2
3
import io
4
from datetime import datetime
5
from typing import IO, TYPE_CHECKING, Any
6
from zoneinfo import ZoneInfo
7
8
import pytest
9
10
import polars as pl
11
from polars.datatypes.group import FLOAT_DTYPES
12
from polars.exceptions import SchemaError
13
from polars.testing import assert_frame_equal
14
15
if TYPE_CHECKING:
16
from collections.abc import Callable
17
18
19
@pytest.mark.parametrize(
    ("literal_values", "expected", "cast_options"),
    [
        # Int32 -> Int64: plain integer upcast.
        (
            (pl.lit(1, dtype=pl.Int64), pl.lit(2, dtype=pl.Int32)),
            pl.Series([1, 2], dtype=pl.Int64),
            pl.ScanCastOptions(integer_cast="upcast"),
        ),
        # Float32 -> Float64: float upcast only.
        (
            (pl.lit(1.0, dtype=pl.Float64), pl.lit(2.0, dtype=pl.Float32)),
            pl.Series([1, 2], dtype=pl.Float64),
            pl.ScanCastOptions(float_cast="upcast"),
        ),
        # Float64 -> Float32: requires the 'downcast' float policy as well.
        (
            (pl.lit(1.0, dtype=pl.Float32), pl.lit(2.0, dtype=pl.Float64)),
            pl.Series([1, 2], dtype=pl.Float32),
            pl.ScanCastOptions(float_cast=["upcast", "downcast"]),
        ),
        # Datetime ns -> ms (no timezone involved).
        (
            (
                pl.lit(datetime(2025, 1, 1), dtype=pl.Datetime(time_unit="ms")),
                pl.lit(datetime(2025, 1, 2), dtype=pl.Datetime(time_unit="ns")),
            ),
            pl.Series(
                [datetime(2025, 1, 1), datetime(2025, 1, 2)],
                dtype=pl.Datetime(time_unit="ms"),
            ),
            pl.ScanCastOptions(datetime_cast="nanosecond-downcast"),
        ),
        # Datetime ns -> ms plus timezone conversion (Sydney -> Amsterdam).
        (
            (
                pl.lit(
                    datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),
                    dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),
                ),
                pl.lit(
                    datetime(2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")),
                    dtype=pl.Datetime(time_unit="ns", time_zone="Australia/Sydney"),
                ),
            ),
            pl.Series(
                [
                    datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),
                    datetime(2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")),
                ],
                dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        (
            ( # We also test nested primitive upcast policy with this one
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"a": [[2]]},
                    dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int8, 1))}),
                ),
            ),
            # Missing field "b" in the second file is inserted as null.
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            pl.ScanCastOptions(
                integer_cast="upcast",
                missing_struct_fields="insert",
            ),
        ),
        (
            ( # Test same set of struct fields but in different order
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"b": None, "a": [[2]]},
                    dtype=pl.Struct(
                        {"b": pl.Int32, "a": pl.List(pl.Array(pl.Int32, 1))}
                    ),
                ),
            ),
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            # Field reordering alone needs no cast options.
            None,
        ),
        # Test logical (datetime) type under list
        (
            (
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        )
                    ),
                ),
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")
                            )
                        }
                    ],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        )
                    ),
                ),
            ),
            pl.Series(
                [
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                ],
                dtype=pl.List(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    )
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Same as above, but with the struct nested under a fixed-size Array.
        (
            (
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
                pl.lit(
                    [
                        {
                            "field": datetime(
                                2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")
                            )
                        }
                    ],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
            ),
            pl.Series(
                [
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                    [
                        {
                            "field": datetime(
                                2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")
                            )
                        }
                    ],
                ],
                dtype=pl.Array(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    ),
                    shape=1,
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Test outer validity
        (
            (
                pl.lit(
                    None,
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        )
                    ),
                ),
                pl.lit(
                    [None],
                    dtype=pl.List(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        )
                    ),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.List(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    )
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
        # Outer validity again, with Array instead of List.
        (
            (
                pl.lit(
                    None,
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ms", time_zone="Europe/Amsterdam"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
                pl.lit(
                    [None],
                    dtype=pl.Array(
                        pl.Struct(
                            {
                                "field": pl.Datetime(
                                    time_unit="ns", time_zone="Australia/Sydney"
                                )
                            }
                        ),
                        shape=1,
                    ),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.Array(
                    pl.Struct(
                        {
                            "field": pl.Datetime(
                                time_unit="ms", time_zone="Europe/Amsterdam"
                            )
                        }
                    ),
                    shape=1,
                ),
            ),
            pl.ScanCastOptions(
                datetime_cast=["nanosecond-downcast", "convert-timezone"]
            ),
        ),
    ],
)
def test_scan_cast_options(
    literal_values: tuple[pl.Expr, pl.Expr],
    expected: pl.Series,
    cast_options: pl.ScanCastOptions | None,
) -> None:
    """Scan two parquet files whose schemas unify only under `cast_options`.

    Each case writes `literal_values` to two separate files and checks that:
    (1) an eager `cast()` to the expected dtype already unifies them,
    (2) a plain `scan_parquet` raises a SchemaError with a hint when
        `cast_options` is required, and
    (3) scanning with `cast_options` yields `expected`.
    """
    expected = expected.alias("literal")
    lv1, lv2 = literal_values

    df1 = pl.select(lv1)
    df2 = pl.select(lv2)

    # `cast()` from the Python API should give the same results.
    assert_frame_equal(
        pl.concat(
            [
                df1.cast(expected.dtype),
                df2.cast(expected.dtype),
            ]
        ),
        expected.to_frame(),
    )

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    df1.write_parquet(files[0])
    df2.write_parquet(files[1])

    for f in files:
        f.seek(0)

    # Note: Schema is taken from the first file

    if cast_options is not None:
        # Without cast options the mismatch must surface as a SchemaError
        # whose message hints at passing cast options.
        q = pl.scan_parquet(files)

        with pytest.raises(pl.exceptions.SchemaError, match=r"hint: .*pass"):
            q.collect()

    assert_frame_equal(
        pl.scan_parquet(files, cast_options=cast_options).collect(),
        expected.to_frame(),
    )
def test_scan_cast_options_forbid_int_downcast() -> None:
    """Ensure `integer_cast='upcast'` does not permit narrowing integer casts.

    The first file establishes an Int8 schema; the second file holds Int32
    data, so unifying with the schema would require a downcast. That must
    fail both with default options and with the upcast-only policy enabled.
    """
    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    pl.select(pl.lit(1, dtype=pl.Int8)).write_parquet(files[0])
    pl.select(pl.lit(2, dtype=pl.Int32)).write_parquet(files[1])

    for buf in files:
        buf.seek(0)

    # Default options: mismatched integer widths are a schema error.
    with pytest.raises(pl.exceptions.SchemaError):
        pl.scan_parquet(files).collect()

    for buf in files:
        buf.seek(0)

    # Upcast-only policy: the Int32 -> Int8 downcast must still be rejected.
    q = pl.scan_parquet(
        files,
        cast_options=pl.ScanCastOptions(integer_cast="upcast"),
    )

    with pytest.raises(pl.exceptions.SchemaError):
        q.collect()
def test_scan_cast_options_extra_struct_fields() -> None:
    """Extra struct fields error by default; `extra_struct_fields='ignore'` drops them."""
    expected = pl.Series(
        [{"a": 1}, {"a": 2}], dtype=pl.Struct({"a": pl.Int32})
    ).alias("literal")

    # File 0 defines the schema; file 1 carries an additional struct field.
    frames = [
        pl.select(pl.lit({"a": 1}, dtype=pl.Struct({"a": pl.Int32}))),
        pl.select(
            pl.lit(
                {"a": 2, "extra_field": 1},
                dtype=pl.Struct({"a": pl.Int32, "extra_field": pl.Int32}),
            )
        ),
    ]

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    for frame, buf in zip(frames, files):
        frame.write_parquet(buf)
        buf.seek(0)

    # Default behavior: the unexpected field is a schema error with a hint.
    with pytest.raises(pl.exceptions.SchemaError, match=r"hint: specify .*or pass"):
        pl.scan_parquet(files).collect()

    # Ignoring extra struct fields yields only the first file's fields.
    assert_frame_equal(
        pl.scan_parquet(
            files, cast_options=pl.ScanCastOptions(extra_struct_fields="ignore")
        ).collect(),
        expected.to_frame(),
    )
def test_cast_options_ignore_extra_columns() -> None:
    """`extra_columns='ignore'` drops columns outside the provided schema."""
    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    pl.DataFrame({"a": 1}).write_parquet(files[0])
    pl.DataFrame({"a": 2, "b": 1}).write_parquet(files[1])

    # By default the unexpected column "b" in the second file is an error.
    with pytest.raises(
        pl.exceptions.SchemaError,
        match=r"extra column in file outside of expected schema: b, hint: specify.* or pass",
    ):
        pl.scan_parquet(files, schema={"a": pl.Int64}).collect()

    # With `extra_columns='ignore'` only the schema column is returned.
    out = pl.scan_parquet(
        files,
        schema={"a": pl.Int64},
        extra_columns="ignore",
    ).collect()

    assert_frame_equal(out, pl.DataFrame({"a": [1, 2]}))
@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        # TODO: Fix for all other formats
        # (pl.scan_ipc, pl.DataFrame.write_ipc),
        # (pl.scan_csv, pl.DataFrame.write_csv),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_cast_options_extra_columns(
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, io.BytesIO], None],
) -> None:
    """Extra columns in later files are an error unless explicitly ignored."""
    frames = [
        pl.DataFrame({"a": 1, "b": 1}),
        pl.DataFrame({"a": 2, "b": 2, "c": 2}),
    ]
    files = [io.BytesIO(), io.BytesIO()]

    for frame, buf in zip(frames, files):
        write_func(frame, buf)

    # The second file's "c" column is outside the first file's schema.
    with pytest.raises(
        pl.exceptions.SchemaError,
        match=r"extra column in file outside of expected schema: c, hint: ",
    ):
        scan_func(files).collect()

    assert_frame_equal(
        scan_func(files, extra_columns="ignore").collect(),  # type: ignore[call-arg]
        pl.DataFrame({"a": [1, 2], "b": [1, 2]}),
    )
@pytest.mark.parametrize("float_dtype", sorted(FLOAT_DTYPES, key=repr))
def test_scan_cast_options_integer_to_float(float_dtype: pl.DataType) -> None:
    """Int -> float schema casts require `integer_cast='allow-float'`."""
    f = io.BytesIO()
    pl.DataFrame({"a": [1]}, schema={"a": pl.Int64}).write_parquet(f)

    f.seek(0)

    # Sanity check: scanning without a schema round-trips the Int64 column.
    assert_frame_equal(
        pl.scan_parquet(f).collect(),
        pl.DataFrame({"a": [1]}, schema={"a": pl.Int64}),
    )

    # Requesting a float schema over integer data errors by default.
    q = pl.scan_parquet(f, schema={"a": float_dtype})

    with pytest.raises(SchemaError):
        q.collect()

    f.seek(0)

    # `allow-float` permits casting the integers to the requested float dtype.
    assert_frame_equal(
        pl.scan_parquet(
            f,
            schema={"a": float_dtype},
            cast_options=pl.ScanCastOptions(integer_cast="allow-float"),
        ).collect(),
        pl.DataFrame({"a": [1.0]}, schema={"a": float_dtype}),
    )