Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/lazyframe/test_collect_all.py
8450 views
1
from pathlib import Path
2
from typing import cast
3
4
import pytest
5
6
import polars as pl
7
from polars.testing import assert_frame_equal
8
9
10
def test_collect_all_type_coercion_21805() -> None:
    """Regression test for issue 21805.

    Shift a float column (introducing a null at the front), fill that null
    with the integer literal ``2``, and check that ``collect_all`` yields the
    expected values for the resulting column.
    """
    shifted = pl.LazyFrame({"A": [1.0, 2.0]}).with_columns(
        pl.col("A").shift().fill_null(2)
    )
    collected = pl.collect_all([shifted])
    values = collected[0]["A"].to_list()
    assert values == [2.0, 1.0]
14
15
16
@pytest.mark.parametrize(
    "optimizations", [pl.QueryOptFlags(), pl.QueryOptFlags.none()]
)
def test_collect_all(df: pl.DataFrame, optimizations: pl.QueryOptFlags) -> None:
    """Collect two independent aggregations in a single ``collect_all`` call.

    Runs once with the default optimization flags and once with every
    optimization disabled (``QueryOptFlags.none()``). ``df`` is a fixture
    defined outside this file (presumably a conftest) — the expected sums
    below depend on its ``int`` and ``floats`` columns.
    """
    int_sum_query = df.lazy().select(pl.col("int").sum())
    doubled_float_sum_query = df.lazy().select((pl.col("floats") * 2).sum())

    frames = pl.collect_all(
        [int_sum_query, doubled_float_sum_query],
        optimizations=optimizations,
    )

    # `cast` only narrows the static type of `.item()`; it has no runtime effect.
    assert cast("int", frames[0].item()) == 6
    assert cast("float", frames[1].item()) == 12.0
23
24
25
def test_collect_all_issue_26097(tmp_path: Path) -> None:
    """Regression test for issue 26097.

    A Parquet scan collected together with an unrelated ``pl.len()`` query in
    the same ``collect_all`` call must still return the scanned data.
    """
    data = pl.DataFrame({"A": [1]})
    tmp_file = tmp_path / "polars-bug-repr.parquet"
    data.write_parquet(tmp_file)

    df = pl.scan_parquet(tmp_file).select(pl.col("A"))

    # An unrelated query placed first in the batch, alongside the scan.
    dummy_df = pl.DataFrame({"v": [1]}).lazy().select(pl.len())
    results = pl.collect_all([dummy_df, df])

    expected = pl.DataFrame({"A": [1]})
    assert_frame_equal(results[1], expected)

    # `tmp_file` is already a Path. pytest cleans up `tmp_path` on its own;
    # the explicit unlink is kept in case it is meant to confirm that no open
    # handle is left on the file (NOTE(review): confirm intent).
    tmp_file.unlink()
39
40
41
def test_collect_all_groupby_lazy_sink_issue_26296(tmp_path: Path) -> None:
    """Regression test for issue 26296.

    Batch a lazy ``sink_parquet`` of a group-by aggregation with a second
    query derived from the same aggregation, and check both outputs: the
    returned literal frame and the Parquet file written by the sink.
    """
    source = pl.DataFrame({"g": ["A"], "v": [1]})
    grouped = source.lazy().group_by("g").agg(pl.col("v").sum())

    out_path = tmp_path / "bug.parquet"
    lazy_sink = grouped.sink_parquet(out_path, lazy=True)
    literal_query = grouped.select(pl.lit(1))

    # Must not raise ColumnNotFoundError for column `v`.
    collected = pl.collect_all([literal_query, lazy_sink])

    assert_frame_equal(
        collected[0],
        pl.DataFrame({"literal": [1]}, schema={"literal": pl.Int32}),
    )
    assert_frame_equal(
        pl.read_parquet(out_path),
        pl.DataFrame({"g": ["A"], "v": [1]}),
    )
57
58