Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_parquet_field_overwrites.py
6939 views
1
import io
2
3
import pyarrow.parquet as pq
4
import pytest
5
6
import polars as pl
7
from polars.io.parquet import ParquetFieldOverwrites
8
9
10
def test_required_flat() -> None:
    """Exercise the ``required`` field overwrite on a flat integer column.

    Three scenarios are covered for a single column named ``"a"``:

    * ``required=False`` -> pyarrow reports the written field as nullable.
    * ``required=True`` with no nulls -> the field is reported as not nullable.
    * ``required=True`` with a null in the data -> sinking raises
      ``InvalidOperationError`` whose message matches ``"missing value"``.
    """
    overwrites_cls = pl.io.parquet.ParquetFieldOverwrites
    buffer = io.BytesIO()

    # Scenario 1: the column is explicitly marked as optional.
    optional_a = overwrites_cls(name="a", required=False)
    lf = pl.Series("a", [1, 2, 3]).to_frame().lazy()
    lf.sink_parquet(buffer, field_overwrites=optional_a)

    buffer.seek(0)
    written_field = pq.read_schema(buffer).field(0)
    assert written_field.nullable

    # Scenario 2: write into the *same* buffer again, starting at offset 0.
    buffer.seek(0)
    required_a = overwrites_cls(name="a", required=True)
    lf = pl.Series("a", [1, 2, 3]).to_frame().lazy()
    lf.sink_parquet(buffer, field_overwrites=required_a)

    # truncate() cuts the buffer at the current stream position, so bytes left
    # over from the first write (if any) cannot trail the new file.
    # NOTE(review): assumes the sink leaves the position at the end of the
    # data it wrote -- confirm.
    buffer.truncate()
    buffer.seek(0)
    written_field = pq.read_schema(buffer).field(0)
    assert not written_field.nullable

    # Scenario 3: a null value in a column declared as required must fail.
    buffer = io.BytesIO()
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        lf = pl.Series("a", [1, 2, 3, None]).to_frame().lazy()
        lf.sink_parquet(
            buffer,
            field_overwrites=overwrites_cls(name="a", required=True),
        )
38
39
40
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64()), pl.Array(pl.Int64(), 1)])
def test_required_list(dtype: pl.DataType) -> None:
    """Exercise the ``required`` overwrite on list-like columns and their items.

    Runs once per parametrized ``dtype`` (a ``List`` and a width-1 ``Array``
    of ``Int64``) and checks, for a column named ``"a"``:

    * outer ``required=True`` only: the outer field is not nullable while the
      value field stays nullable, even with a null item inside a list;
    * outer ``required=True`` with a null list: sinking raises;
    * outer and child ``required=True`` with a null item: sinking raises;
    * outer and child ``required=True`` without nulls: neither the outer field
      nor the value field is nullable.
    """

    def sink(
        values: "list[list[int | None] | None]",
        target: io.BytesIO,
        overwrites: pl.io.parquet.ParquetFieldOverwrites,
    ) -> None:
        # Build a one-column lazy frame of the parametrized dtype and sink it.
        pl.Series("a", values, dtype).to_frame().lazy().sink_parquet(
            target,
            field_overwrites=overwrites,
        )

    # Only the outer list is required; a null *item* is still acceptable.
    buffer = io.BytesIO()
    sink(
        [[1], [2], [3], [None]],
        buffer,
        pl.io.parquet.ParquetFieldOverwrites(name="a", required=True),
    )
    buffer.seek(0)
    schema = pq.read_schema(buffer)
    outer_field = schema.field(0)
    assert not outer_field.nullable
    assert outer_field.type.value_field.nullable

    # A null list violates the outer ``required=True``.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        sink(
            [[1], [2], [3], None],
            io.BytesIO(),
            pl.io.parquet.ParquetFieldOverwrites(name="a", required=True),
        )

    # A null item violates ``required=True`` on the child field.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        sink(
            [[1], [2], [3], [None]],
            io.BytesIO(),
            pl.io.parquet.ParquetFieldOverwrites(
                name="a",
                required=True,
                children=pl.io.parquet.ParquetFieldOverwrites(required=True),
            ),
        )

    # No nulls anywhere: both levels are written as required.
    buffer = io.BytesIO()
    sink(
        [[1], [2], [3], [4]],
        buffer,
        pl.io.parquet.ParquetFieldOverwrites(
            name="a",
            required=True,
            children=pl.io.parquet.ParquetFieldOverwrites(required=True),
        ),
    )
    buffer.seek(0)
    schema = pq.read_schema(buffer)
    outer_field = schema.field(0)
    assert not outer_field.nullable
    assert not outer_field.type.value_field.nullable
83
84
85
def test_required_struct() -> None:
    """Exercise the ``required`` overwrite on a struct column and its field.

    For a struct column ``"a"`` with a single field ``"x"`` this checks:

    * outer ``required=True`` only, no nulls: the struct is not nullable and
      its field stays nullable;
    * outer ``required=True`` only, with a null ``x``: same schema result;
    * outer and ``x`` both ``required=True`` with a null ``x``: sinking raises
      ``InvalidOperationError`` matching ``"missing value"``;
    * outer and ``x`` both ``required=True`` without nulls: neither the struct
      nor its field is nullable.
    """

    def sink(
        rows: "list[dict[str, int | None]]",
        target: io.BytesIO,
        overwrites: ParquetFieldOverwrites,
    ) -> None:
        # Build a one-column lazy frame of structs and sink it to ``target``.
        pl.Series("a", rows).to_frame().lazy().sink_parquet(
            target,
            field_overwrites=overwrites,
        )

    # Struct required, field left at its default: no nulls present.
    buffer = io.BytesIO()
    sink(
        [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}],
        buffer,
        pl.io.parquet.ParquetFieldOverwrites(
            name="a",
            required=True,
        ),
    )
    buffer.seek(0)
    schema = pq.read_schema(buffer)
    struct_field = schema.field(0)
    assert not struct_field.nullable
    assert struct_field.type.fields[0].nullable

    # Same overwrite, but a null inside the struct field is tolerated.
    buffer = io.BytesIO()
    sink(
        [{"x": 1}, {"x": None}, {"x": 2}, {"x": 3}],
        buffer,
        pl.io.parquet.ParquetFieldOverwrites(
            name="a",
            required=True,
        ),
    )
    buffer.seek(0)
    schema = pq.read_schema(buffer)
    struct_field = schema.field(0)
    assert not struct_field.nullable
    assert struct_field.type.fields[0].nullable

    # Marking ``x`` itself as required makes the null value an error.
    with pytest.raises(pl.exceptions.InvalidOperationError, match="missing value"):
        sink(
            [{"x": 1}, {"x": None}, {"x": 2}, {"x": 3}],
            io.BytesIO(),
            ParquetFieldOverwrites(
                name="a",
                required=True,
                children={"x": ParquetFieldOverwrites(required=True)},
            ),
        )

    # Both levels required and no nulls: both are written as non-nullable.
    buffer = io.BytesIO()
    sink(
        [{"x": 1}, {"x": 2}, {"x": 2}, {"x": 3}],
        buffer,
        ParquetFieldOverwrites(
            name="a",
            required=True,
            children={"x": ParquetFieldOverwrites(required=True)},
        ),
    )
    buffer.seek(0)
    schema = pq.read_schema(buffer)
    struct_field = schema.field(0)
    assert not struct_field.nullable
    assert not struct_field.type.fields[0].nullable
144
145