# GitHub Repository: pola-rs/polars
# Path: blob/main/py-polars/tests/unit/lazyframe/test_schema.py
import pickle
from datetime import datetime
from typing import Any

import pytest

import polars as pl
from polars.datatypes.group import NUMERIC_DTYPES, TEMPORAL_DTYPES
from polars.testing.asserts.frame import assert_frame_equal

# Used by test_lazy_collect_schema_matches_computed_schema
_TEST_COLLECT_SCHEMA_M_DTYPES = sorted(
    ({pl.Boolean, pl.String} | NUMERIC_DTYPES | TEMPORAL_DTYPES) - {pl.Decimal},
    key=repr,
)


def test_schema() -> None:
    s = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})

    assert s["foo"] == pl.Int8()
    assert s["bar"] == pl.String()
    assert s.len() == 2
    assert s.names() == ["foo", "bar"]
    assert s.dtypes() == [pl.Int8(), pl.String()]

    with pytest.raises(
        TypeError,
        match="dtypes must be fully-specified, got: List",
    ):
        pl.Schema({"foo": pl.String, "bar": pl.List})


@pytest.mark.parametrize(
    "schema",
    [
        pl.Schema(),
        pl.Schema({"foo": pl.Int8()}),
        pl.Schema({"foo": pl.Datetime("us"), "bar": pl.String()}),
        pl.Schema(
            {
                "foo": pl.UInt32(),
                "bar": pl.Categorical(),
                "baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}),
            }
        ),
    ],
)
def test_schema_empty_frame(schema: pl.Schema) -> None:
    assert_frame_equal(
        schema.to_frame(),
        pl.DataFrame(schema=schema),
    )


def test_schema_equality() -> None:
    s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
    s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
    s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()})

    assert s1 == s1
    assert s2 == s2
    assert s3 == s3
    assert s1 != s2
    assert s1 != s3
    assert s2 != s3

    s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")})
    s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")})
    s6 = {"foo": pl.Datetime, "bar": pl.Duration}

    assert s4 != s5
    assert s4 != s6


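# Python types such as `int` and `datetime` map to the corresponding Polars dtypes.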
def test_schema_parse_python_dtypes() -> None:
    cardinal_directions = pl.Enum(["north", "south", "east", "west"])

    s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions})  # type: ignore[arg-type]
    s["ham"] = datetime

    assert s["foo"] == pl.List(pl.Int32)
    assert s["bar"] == pl.Int64
    assert s["baz"] == cardinal_directions
    assert s["ham"] == pl.Datetime("us")

    assert s.len() == 4
    assert s.names() == ["foo", "bar", "baz", "ham"]
    assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")]

    assert list(s.to_python().values()) == [list, int, str, datetime]
    assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime]


def test_schema_picklable() -> None:
    s = pl.Schema(
        {
            "foo": pl.Int8(),
            "bar": pl.String(),
            "ham": pl.Struct({"x": pl.List(pl.Date)}),
        }
    )
    pickled = pickle.dumps(s)
    s2 = pickle.loads(pickled)
    assert s == s2


def test_schema_python() -> None:
    input = {
        "foo": pl.Int8(),
        "bar": pl.String(),
        "baz": pl.Categorical(),
        "ham": pl.Object(),
        "spam": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}),
    }
    expected = {
        "foo": int,
        "bar": str,
        "baz": str,
        "ham": object,
        "spam": dict,
    }
    for schema in (input, input.items(), list(input.items())):
        s = pl.Schema(schema)
        assert expected == s.to_python()


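# Inside an aggregation, `map_elements` with an explicit `return_dtype` should produce a scalar Float64 column, not a list.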
def test_schema_in_map_elements_returns_scalar() -> None:
    schema = pl.Schema([("portfolio", pl.String()), ("irr", pl.Float64())])

    ldf = pl.LazyFrame(
        {
            "portfolio": ["A", "A", "B", "B"],
            "amounts": [100.0, -110.0] * 2,
        }
    )
    q = ldf.group_by("portfolio").agg(
        pl.col("amounts")
        .implode()
        .map_elements(lambda x: float(x.sum()), return_dtype=pl.Float64)
        .alias("irr")
    )
    assert q.collect_schema() == schema
    assert q.collect().schema == schema


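# Verify that `collect_schema()` matches the schema of the materialized result for common
# binary operators across pairs of dtypes; combinations whose computation raises are skipped.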
@pytest.mark.slow
@pytest.mark.parametrize(
    "expr",
    [
        # TODO: Add more (bitwise) operators once their types are resolved correctly
        pl.col("col0") > pl.col("col1"),
        pl.col("col0") >= pl.col("col1"),
        pl.col("col0") < pl.col("col1"),
        pl.col("col0") <= pl.col("col1"),
        pl.col("col0") == pl.col("col1"),
        pl.col("col0") != pl.col("col1"),
        pl.col("col0") + pl.col("col1"),
        pl.col("col0") - pl.col("col1"),
        pl.col("col0") * pl.col("col1"),
        pl.col("col0") / pl.col("col1"),
        pl.col("col0").truediv(pl.col("col1")),
        pl.col("col0") // pl.col("col1"),
        pl.col("col0") % pl.col("col1"),
    ],
)
@pytest.mark.parametrize("dtype1", _TEST_COLLECT_SCHEMA_M_DTYPES)
@pytest.mark.parametrize("dtype2", _TEST_COLLECT_SCHEMA_M_DTYPES)
def test_lazy_collect_schema_matches_computed_schema(
    expr: pl.Expr, dtype1: pl.DataType, dtype2: pl.DataType
) -> None:
    df = pl.DataFrame(
        {
            "col0": [None],
            "col1": [None],
        },
        schema={
            "col0": dtype1,
            "col1": dtype2,
        },
    )
    lazy_df = df.lazy().select(expr)

    expected_schema = None
    try:
        expected_schema = lazy_df.collect().schema
    except (
        # Applying the operator to these dtypes will result in an error,
        # so their output dtype is undefined
        pl.exceptions.InvalidOperationError,
        pl.exceptions.SchemaError,
        pl.exceptions.ComputeError,
    ):
        return

    actual_schema = lazy_df.collect_schema()
    assert actual_schema == expected_schema, (
        f"{expr} on {df.dtypes} results in {actual_schema} instead of {expected_schema}\n"
        f"result of computation is:\n{lazy_df.collect()}\n"
    )


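# Regression test: calling `collect_schema()` up front must not change the result of concatenating the LazyFrame with itself.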
def test_ir_cache_unique_18198() -> None:
    lf = pl.LazyFrame({"a": [1]})
    lf.collect_schema()
    assert pl.concat([lf, lf]).collect().to_dict(as_series=False) == {"a": [1, 1]}


def test_schema_functions_in_agg_with_literal_arg_19011() -> None:
    q = (
        pl.LazyFrame({"a": [1, 2, 3, None, 5]})
        .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i")
        .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2"))
    )
    assert q.collect_schema() == pl.Schema(
        [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))]
    )


def test_lazy_explode_in_agg_schema_19562() -> None:
    def new_df_check_schema(
        value: dict[str, Any], schema: dict[str, Any]
    ) -> pl.DataFrame:
        df = pl.DataFrame(value)
        assert df.schema == schema
        return df

    lf = pl.LazyFrame({"a": [1], "b": [[1]]})

    q = lf.group_by("a").agg(pl.col("b"))
    schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))}

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))

    q = lf.group_by("a").agg(pl.col("b").explode().explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))

    # 2x nested
    lf = pl.LazyFrame({"a": [1], "b": [[[1]]]})

    q = lf.group_by("a").agg(pl.col("b"))
    schema = {
        "a": pl.Int64,
        "b": pl.List(pl.List(pl.List(pl.Int64))),
    }

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[[1]]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))}

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode().explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))


def test_lazy_nested_function_expr_agg_schema() -> None:
    q = (
        pl.LazyFrame({"k": [1, 1, 2]})
        .group_by(pl.first(), maintain_order=True)
        .agg(o=pl.int_range(pl.len()).reverse() < 1)
    )

    assert q.collect_schema() == {"k": pl.Int64, "o": pl.List(pl.Boolean)}
    assert_frame_equal(
        q.collect(), pl.DataFrame({"k": [1, 2], "o": [[False, True], [True]]})
    )


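# A scalar aggregation such as `null_count()` yields the index dtype (`pl.get_index_type()`), not the input column dtype.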
def test_lazy_agg_scalar_return_schema() -> None:
    q = pl.LazyFrame({"k": [1]}).group_by("k").agg(pl.col("k").null_count().alias("o"))

    schema = {"k": pl.Int64, "o": pl.get_index_type()}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lazy_agg_nested_expr_schema() -> None:
    q = (
        pl.LazyFrame({"k": [1]})
        .group_by("k")
        .agg(
            (
                (
                    (pl.col("k").reverse().shuffle() + 1)
                    + pl.col("k").shuffle().reverse()
                )
                .shuffle()
                .reverse()
                .sum()
                * 0
            ).alias("o")
        )
    )

    schema = {"k": pl.Int64, "o": pl.Int64}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lazy_agg_lit_explode() -> None:
    q = (
        pl.LazyFrame({"k": [1]})
        .group_by("k")
        .agg(pl.lit(1, dtype=pl.Int64).explode().alias("o"))
    )

    schema = {"k": pl.Int64, "o": pl.List(pl.Int64)}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema))  # type: ignore[arg-type]


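# Aggregations that reduce to a scalar should report the same schema from `collect_schema()`
# as from the collected frame, even after order-changing operations such as `reverse`/`shuffle`.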
@pytest.mark.parametrize(
    "expr_op", [
        "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
        "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
        "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
        "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
        "var"
    ]
)  # fmt: skip
@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
    op = getattr(pl.Expr, expr_op)

    lf = pl.LazyFrame({"a": 1, "b": 1})

    q = lf.group_by("a").agg(lhs.reverse().pipe(op))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op))

    assert q.collect_schema() == q.collect().collect_schema()


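# `collect_schema()` should stay in sync with the collected schema when element-wise operations follow a scalar aggregation within `agg`.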
def test_lazy_agg_schema_after_elementwise_19984() -> None:
    lf = pl.LazyFrame({"a": 1, "b": 1})

    q = lf.group_by("a").agg(pl.col("b").item().fill_null(0))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(pl.col("b").item().fill_null(0).fill_null(0))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(pl.col("b").item() + 1)
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(1 + pl.col("b").item())
    assert q.collect_schema() == q.collect().collect_schema()


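# Window expressions via `over` should report matching schemas for every mapping strategy.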
@pytest.mark.parametrize(
    "expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()]
)
@pytest.mark.parametrize("mapping_strategy", ["explode", "join", "group_to_rows"])
def test_lazy_window_schema(expr: pl.Expr, mapping_strategy: str) -> None:
    q = pl.LazyFrame({"a": 1, "b": 1}).select(
        expr.over("a", mapping_strategy=mapping_strategy)  # type: ignore[arg-type]
    )

    assert q.collect_schema() == q.collect().collect_schema()


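# Exploding Array and List columns should remove one level of nesting from the schema,
# whether done through expressions or through `LazyFrame.explode`.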
def test_lazy_explode_schema() -> None:
    lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.Array(pl.Int64, 1))})

    q = lf.select(pl.col("x").explode())
    assert q.collect_schema() == {"x": pl.Int64}

    q = lf.select(pl.col("x").arr.explode())
    assert q.collect_schema() == {"x": pl.Int64}

    lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.List(pl.Int64))})

    q = lf.select(pl.col("x").explode())
    assert q.collect_schema() == {"x": pl.Int64}

    q = lf.select(pl.col("x").list.explode())
    assert q.collect_schema() == {"x": pl.Int64}

    # `LazyFrame.explode()` goes through a different codepath than `Expr.explode`
    lf = pl.LazyFrame().with_columns(
        pl.Series([[1]], dtype=pl.List(pl.Int64)).alias("list"),
        pl.Series([[1]], dtype=pl.Array(pl.Int64, 1)).alias("array"),
    )

    q = lf.explode("*")
    assert q.collect_schema() == {"list": pl.Int64, "array": pl.Int64}

    q = lf.explode("list")
    assert q.collect_schema() == {"list": pl.Int64, "array": pl.Array(pl.Int64, 1)}


def test_raise_subnodes_18787() -> None:
    df = pl.DataFrame({"a": [1], "b": [2]})

    with pytest.raises(pl.exceptions.ColumnNotFoundError):
        (
            df.select(pl.struct(pl.all())).select(
                pl.first().struct.field("a", "b").filter(pl.col("foo") == 1)
            )
        )


def test_scalar_agg_schema_20044() -> None:
    assert (
        pl.DataFrame(None, schema={"a": pl.Int64, "b": pl.String, "c": pl.String})
        .with_columns(d=pl.col("a").max())
        .group_by("c")
        .agg(pl.col("d").mean())
    ).schema == pl.Schema([("c", pl.String), ("d", pl.Float64)])


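# `collect_schema()` for a division expression should match the collected schema for plain, Array, and List inputs.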
@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [None, True, False], "b": 3 * [128]}),
        pl.DataFrame(
            {"a": [[None, True, False]], "b": [3 * [128]]},
            schema={"a": pl.Array(pl.Boolean, 3), "b": pl.Array(pl.Int64, 3)},
        ),
        pl.DataFrame(
            {"a": [[None, True, False]], "b": [3 * [128]]},
            schema={"a": pl.List(pl.Boolean), "b": pl.List(pl.Int64)},
        ),
    ],
)
def test_div_collect_schema_matches_23993(df: pl.DataFrame) -> None:
    q = df.lazy().select(pl.col("a") / pl.col("b"))
    expected = q.collect().schema
    actual = q.collect_schema()
    assert actual == expected